# Coding a Transformer from Scratch

### Cristiano De Nobili - My Contacts
For any questions or doubts you can find my contacts here:

<p align="center">

[<img src="https://img.freepik.com/premium-vector/linkedin-logo_578229-227.jpg?w=1060" width="25">](https://www.linkedin.com/in/cristiano-de-nobili/) [<img src="https://1.bp.blogspot.com/-Rwqcet_SHbk/T8_acMUmlmI/AAAAAAAAGgw/KD_fx__8Q4w/s1600/Twitter+bird.png" width="30">](https://twitter.com/denocris)        

</p>

or here (https://denocris.com).




Some references:

* Original paper: [Attention Is All You Need](https://arxiv.org/pdf/1706.03762.pdf)

* [The Illustrated Transformer - Jay Alammar](http://jalammar.github.io/illustrated-transformer/)


Disclaimers:

We will train a plain Transformer model for classification in the standard supervised way. This is not BERT: BERT is also a stack of Transformer layers, but it is pre-trained according to the Masked Language Model (MLM) paradigm.

In [None]:
%%capture
!pip install torchtext==0.6.0

# After the installation, restart the session

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


#from torchtext.legacy import data, datasets

from torchtext import data, datasets
from torchtext import vocab
import numpy as np
import random, tqdm, sys, math, gzip
from torchsummary import summary

from torch.utils.tensorboard import SummaryWriter

In [None]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

In [None]:
!nvidia-smi

Mon May 20 11:41:12 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P8              11W /  70W |      3MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

### Multi-head Attention Mechanism, step by step

First, let's set some hyperparameters. To keep things simple, we choose small values.

In [None]:
emb = 128 # embedding dimension (BERT like models 768)
h = 8 # number of heads (BERT has 12 heads)

batch_size = 4
sentence_length = 21 # Context Length, 512 for BERT

Some fake random data with proper dimensions

In [None]:
x = torch.rand(batch_size, sentence_length, emb)

b, t, e = x.size()

In [None]:
x.size()

torch.Size([4, 21, 128])

Instantiate linear transformations for query, key and values. Each transformation will act on the input vector x.

In [None]:
tokeys    = nn.Linear(emb, emb, bias=False) # W_key
toqueries = nn.Linear(emb, emb, bias=False) # W_query
tovalues  = nn.Linear(emb, emb, bias=False) # W_value

Generate queries, keys and values. We first compute the k/q/v's on the whole embedding vectors, and then split into the different heads.

In [None]:
keys    = tokeys(x) # W_key x
queries = toqueries(x)
values  = tovalues(x)

In [None]:
print(keys.size())

torch.Size([4, 21, 128])


Now implement multi-head attention (the lighter version), splitting the projections into the different heads. Each head works on a slice of size `emb // h`.

In [None]:
s = e // h # 128 / 8

keys    = keys.view(b, t, h, s)
queries = queries.view(b, t, h, s)
values  = values.view(b, t, h, s)

print(keys.size())

torch.Size([4, 21, 8, 16])


In [None]:
keys.transpose(1, 2).size()

torch.Size([4, 8, 21, 16])

In [None]:
keys.transpose(1, 2).contiguous().view(b * h, t, s).size()

torch.Size([32, 21, 16])

We now need to compute the dot products. This is the same operation for every head, so we fold the heads into the batch dimension.

In [None]:
keys = keys.transpose(1, 2).contiguous().view(b * h, t, s)
queries = queries.transpose(1, 2).contiguous().view(b * h, t, s)
values = values.transpose(1, 2).contiguous().view(b * h, t, s)

# transpose(1, 2) doesn't generate a new tensor with a new layout; it just modifies
# meta information in the Tensor object so that the offset and stride describe the new shape.
# contiguous() then copies the data (if it is not already contiguous) so that the order of
# its elements in memory is the same as if the tensor had been created from scratch.
# view() needs this layout to merge the b and h dimensions.
# https://discuss.pytorch.org/t/contigious-vs-non-contigious-tensor/30107

keys.size()

torch.Size([32, 21, 16])
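
To see why `contiguous()` is needed before `view()`, here is a minimal sketch on a random tensor with the same shape as above (the variable names are illustrative only). It also shows `reshape()`, which copies only when necessary, as a possible one-call alternative.

```python
import torch

x = torch.rand(4, 21, 8, 16)   # (b, t, h, s), freshly created, so contiguous
xt = x.transpose(1, 2)         # (b, h, t, s): same storage, different strides

print(x.is_contiguous(), xt.is_contiguous())

# view() needs compatible strides, so merging b and h fails on the transposed tensor
try:
    xt.view(4 * 8, 21, 16)
    view_failed = False
except RuntimeError:
    view_failed = True
print(view_failed)

# contiguous() copies the data into the new layout, after which view() works
folded = xt.contiguous().view(4 * 8, 21, 16)

# reshape() gives the same result in one call
same = torch.equal(folded, xt.reshape(4 * 8, 21, 16))
print(same)
```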

Perform dot products

In [None]:
print(queries.size())

print(keys.transpose(1, 2).size())

# Let's compute the attention matrix
attn_scores = torch.bmm(queries, keys.transpose(1, 2))  # batch matrix-matrix product

print(attn_scores.size())

torch.Size([32, 21, 16])
torch.Size([32, 16, 21])
torch.Size([32, 21, 21])
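
The remaining steps of the attention mechanism (scaling, softmax, weighting the values, and unfolding the heads) can be sketched as follows. This is a self-contained illustration on random tensors; it assumes the scaling of the paper, i.e. dividing the scores by the square root of the per-head dimension `s`.

```python
import torch
import torch.nn.functional as F

b, h, t, s = 4, 8, 21, 16
queries = torch.rand(b * h, t, s)
keys    = torch.rand(b * h, t, s)
values  = torch.rand(b * h, t, s)

# scaled dot products between every query and every key
attn_scores = torch.bmm(queries, keys.transpose(1, 2)) / (s ** 0.5)  # (b*h, t, t)

# softmax over the key dimension: each row becomes a probability distribution
attn_weights = F.softmax(attn_scores, dim=2)

# weighted sum of the values
out = torch.bmm(attn_weights, values)                                # (b*h, t, s)

# unfold the heads from the batch dimension and concatenate them
out = out.view(b, h, t, s).transpose(1, 2).contiguous().view(b, t, h * s)
print(out.size())  # torch.Size([4, 21, 128])
```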


💡 Note: for 3D tensors, `torch.bmm` and the `@` operator compute the same thing. You can check easily:

```
mat1 = torch.randn(10, 3, 4)
mat2 = torch.randn(10, 4, 5)
res_1 = torch.bmm(mat1, mat2)
res_2 = mat1 @ mat2
print(torch.allclose(res_1, res_2))
```
In particular:
* `torch.bmm`: Performs a batch matrix-matrix product of 3D tensors `mat1.size() = [b,n,m]` and `mat2.size() = [b,m,p]`, outputting `res.size() = [b,n,p]`. Here is the [documentation](https://pytorch.org/docs/stable/generated/torch.bmm.html#torch.bmm).
* symbol `@`: The matrix multiplication(s) are done between the last two dimensions. The remaining dimensions are broadcast and batched.
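
One practical consequence of the broadcasting behaviour of `@`: the heads do not have to be folded into the batch dimension. The sketch below, on random tensors, compares the 4D `@` product with the folded `torch.bmm` version used in this notebook.

```python
import torch

b, h, t, s = 4, 8, 21, 16
q = torch.rand(b, h, t, s)
k = torch.rand(b, h, t, s)

# @ multiplies the last two dimensions and batches over b and h
scores_4d = q @ k.transpose(-2, -1)   # (b, h, t, t)

# torch.bmm only accepts 3D tensors, so the heads are folded into the batch
scores_3d = torch.bmm(q.reshape(b * h, t, s),
                      k.reshape(b * h, t, s).transpose(1, 2))  # (b*h, t, t)

match = torch.allclose(scores_4d.reshape(b * h, t, t), scores_3d, atol=1e-5)
print(match)
```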




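For completeness, below is the wide variant of multi-head attention, in which every head gets a full `emb`-dimensional query, key and value. It is computationally more intensive. Note that the paper itself uses the narrow variant shown above, with per-head dimension `d_model / h`.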

In [None]:
emb = 128
h = 8

x = torch.rand(4, 21, emb)

b, t, e = x.size()

tokeys    = nn.Linear(emb, emb * h, bias=False)
toqueries = nn.Linear(emb, emb * h, bias=False)
tovalues  = nn.Linear(emb, emb * h, bias=False)

keys    = tokeys(x)
queries = toqueries(x)
values  = tovalues(x)

print(keys.size())

keys    = keys.view(b, t, h, e)
queries = queries.view(b, t, h, e)
values  = values.view(b, t, h, e)

print(keys.size())


torch.Size([4, 21, 1024])
torch.Size([4, 21, 8, 128])
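
To make the cost difference concrete, this small sketch counts the parameters of a single projection (keys, for instance) in the two variants, using the same `emb` and `h` as above.

```python
import torch.nn as nn

emb, h = 128, 8

narrow = nn.Linear(emb, emb, bias=False)      # one emb x emb projection, split across heads
wide   = nn.Linear(emb, emb * h, bias=False)  # a full emb x emb projection per head

n_narrow = sum(p.numel() for p in narrow.parameters())
n_wide   = sum(p.numel() for p in wide.parameters())

print(n_narrow, n_wide, n_wide // n_narrow)   # 16384 131072 8
```

The wide variant has `h` times more projection parameters, and its attention scores are computed over vectors that are `h` times longer.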


### Model Definition

Let us collect everything and define the self-attention class

In [None]:
class MHSelfAttention(nn.Module):
    """
    Multi-head self attention.
    """

    def __init__(self, emb, heads=8):
        """
        :param emb: embedding dimension
        :param heads: number of attention heads
        """
        super().__init__()

        assert emb % heads == 0, f'Embedding dimension ({emb}) should be divisible by nr. of heads ({heads})'

        self.emb = emb
        self.heads = heads

        #s = emb // heads
        # - We will break the embedding into `heads` chunks and feed each to a different attention head

        self.tokeys    = nn.Linear(emb, emb, bias=False) # W_key
        self.toqueries = nn.Linear(emb, emb, bias=False) # W_query
        self.tovalues  = nn.Linear(emb, emb, bias=False) # W_value

        self.unifyheads = nn.Linear(emb, emb)

    def forward(self, x):

        b, t, e = x.size()
        h = self.heads
        assert e == self.emb, f'Input embedding dim ({e}) should match layer embedding dim ({self.emb})'

        s = e // h

        # We first compute the k/q/v's on the whole embedding vectors, and then split into the different heads.

        keys    = self.tokeys(x)
        queries = self.toqueries(x)
        values  = self.tovalues(x)

        # Split into the different heads.

        keys    = keys.view(b, t, h, s)
        queries = queries.view(b, t, h, s)
        values  = values.view(b, t, h, s)

        # Compute scaled dot-product self-attention

        # Fold heads into the batch dimension
        # When you call contiguous(), it actually makes a copy of the tensor
        # such that the order of its elements in memory is the same as if it had been created from scratch with the same data.
        keys = keys.transpose(1, 2).contiguous().view(b * h, t, s)
        queries = queries.transpose(1, 2).contiguous().view(b * h, t, s)
        values = values.transpose(1, 2).contiguous().view(b * h, t, s)

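        # Instead of dividing the dot products by sqrt(e), we scale the queries and keys
        # by e ** (1/4) each. This should be more memory efficient.
        # Note: the paper scales by sqrt(d_k), with d_k the per-head dimension (s here);
        # this code uses the full embedding dimension e.
        queries = queries / (e ** (1/4))
        keys    = keys / (e ** (1/4))

        # Get the dot product of the (already scaled) queries and keys.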

        attn_scores = torch.bmm(queries, keys.transpose(1, 2))

        assert attn_scores.size() == (b * h, t, t)

        attn_weights = F.softmax(attn_scores, dim=2) # each row now holds self-attention probabilities

        # apply the self attention to the values
        out = torch.bmm(attn_weights, values).view(b, h, t, s)

        # swap h, t back, unify heads
        out = out.transpose(1, 2).contiguous().view(b, t, s * h)

        return self.unifyheads(out)
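
The class scales queries and keys by `e ** (1/4)` each, rather than dividing the scores by `sqrt(e)`. A quick check on random matrices (a sketch, independent of the class) shows that the two are numerically equivalent:

```python
import math
import torch

t, e = 21, 128
q = torch.rand(t, e)
k = torch.rand(t, e)

# scale the inputs by e ** (1/4) each ...
scaled_inputs = (q / e ** (1 / 4)) @ (k / e ** (1 / 4)).T

# ... or scale the dot products by sqrt(e)
scaled_scores = (q @ k.T) / math.sqrt(e)

equivalent = torch.allclose(scaled_inputs, scaled_scores, atol=1e-5)
print(equivalent)
```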

A Transformer block combines self-attention with layer normalization, residual connections and a small feed-forward network.

In [None]:
class TransformerBlock(nn.Module):

    def __init__(self, emb, heads, mask, seq_length, ff_hidden_mult=4, dropout=0.0, pos_embedding=None):
        super().__init__()

        self.mhattention = MHSelfAttention(emb, heads=heads)

        self.norm1 = nn.LayerNorm(emb)
        self.norm2 = nn.LayerNorm(emb)

        self.ff = nn.Sequential(
            nn.Linear(emb, ff_hidden_mult * emb),
            nn.ReLU(),
            nn.Linear(ff_hidden_mult * emb, emb)
        )

        self.do = nn.Dropout(dropout)

    def forward(self, x):

        attended = self.mhattention(x)

        x = self.norm1(attended + x) #residual

        x = self.do(x)

        fedforward = self.ff(x)

        x = self.norm2(fedforward + x) #residual

        x = self.do(x)

        return x
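
As a reminder of what `nn.LayerNorm(emb)` does inside the block, the sketch below applies it to a random tensor with an arbitrary scale and offset: each token vector is normalized over the embedding dimension to roughly zero mean and unit standard deviation (at initialization, before the learnable scale and shift have been trained).

```python
import torch
import torch.nn as nn

emb = 128
x = torch.rand(4, 21, emb) * 10 + 3   # arbitrary scale and offset
norm = nn.LayerNorm(emb)              # normalizes over the last (embedding) dimension

y = norm(x)

max_abs_mean = y.mean(dim=-1).abs().max().item()       # close to 0 for every token
mean_std = y.std(dim=-1, unbiased=False).mean().item() # close to 1
print(max_abs_mean, mean_std)
```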

Let's build a Transformer (a stack of Transformer blocks) and adapt it for a binary classification task. Its `depth` defines the number of Transformer blocks.

In [None]:
class CTransformer(nn.Module):

    def __init__(self, emb, heads, depth, seq_length, num_tokens, num_classes, max_pool=True, dropout=0.0):
        """
        :param emb: Embedding dimension
        :param heads: nr. of attention heads
        :param depth: Number of transformer blocks
        :param seq_length: Expected maximum sequence length
        :param num_tokens: Number of tokens (usually words) in the vocabulary
        :param num_classes: Number of classes.
        :param max_pool: If true, use global max pooling in the last layer. If false, use global
                         average pooling.
        """
        super().__init__()

        self.num_tokens, self.max_pool = num_tokens, max_pool

        # Token embedding
        self.token_embedding = nn.Embedding(embedding_dim=emb, num_embeddings=num_tokens)
        # Position embedding
        self.pos_embedding = nn.Embedding(embedding_dim=emb, num_embeddings=seq_length)

        tblocks = []
        for i in range(depth):
            tblocks.append(
                TransformerBlock(emb=emb, heads=heads, seq_length=seq_length, mask=False, dropout=dropout))

        self.tblocks = nn.Sequential(*tblocks)

        self.toprobs = nn.Linear(emb, num_classes)

        self.do = nn.Dropout(dropout)

    def forward(self, x):
        """
        :param x: A batch by sequence length integer tensor of token indices.
        :return: predicted log-probability vectors over the classes, one per sequence.
        """
        tokens = self.token_embedding(x)
        b, t, e = tokens.size()

        # Build the position indices on the same device as the input
        positions = self.pos_embedding(torch.arange(t, device=x.device))[None, :, :].expand(b, t, e)
        x = tokens + positions
        x = self.do(x)

        x = self.tblocks(x)

        x = x.max(dim=1)[0] if self.max_pool else x.mean(dim=1) # pool over the time dimension

        x = self.toprobs(x)

        return F.log_softmax(x, dim=1) # log-probabilities, to be paired with F.nll_loss

In [None]:
torch.arange(21)
# Position indices 0..20: index i selects row i of the position embedding table.
# The same position vectors are then shared by every sequence in the batch.

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20])
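
The way `CTransformer.forward` turns these indices into a `(b, t, e)` tensor can be sketched in isolation as follows (a standalone example with a freshly initialized embedding table, not the trained one):

```python
import torch
import torch.nn as nn

b, t, e = 4, 21, 128
pos_embedding = nn.Embedding(num_embeddings=256, embedding_dim=e)

positions = pos_embedding(torch.arange(t))         # (t, e): one vector per position
positions = positions[None, :, :].expand(b, t, e)  # (b, t, e): a broadcast view, no copy

print(positions.size())                            # torch.Size([4, 21, 128])
print(torch.equal(positions[0], positions[3]))     # True: all sequences share the vectors
```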

### Data Preparation

One of the main concepts of TorchText is the `Field`, which defines how your data should be processed. In our sentiment classification task the data consists of both the raw string of the review and the sentiment, either "pos" or "neg".

The parameters of a `Field` specify how the data should be processed.

We use the `TEXT` field to define how the review should be processed, and the `LABEL` field to process the sentiment.
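
Conceptually, a text field lowercases and tokenizes each review, builds a vocabulary, and maps tokens to integer indices. The pure-Python sketch below mimics these steps on toy data; it is an illustration of the idea, not TorchText's actual implementation, and the helper `numericalize` is hypothetical.

```python
from collections import Counter

reviews = ["A great movie", "a boring movie"]   # toy data, not from IMDB

# 1. preprocess: lowercase, then split on whitespace
tokenized = [r.lower().split() for r in reviews]

# 2. build the vocabulary: special tokens first, then words by descending frequency
counts = Counter(tok for sent in tokenized for tok in sent)
ordered = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))
itos = ["<unk>", "<pad>"] + [w for w, _ in ordered]
stoi = {w: i for i, w in enumerate(itos)}

# 3. numericalize: words outside the vocabulary map to <unk> (index 0)
def numericalize(sentence):
    return [stoi.get(tok, stoi["<unk>"]) for tok in sentence.lower().split()]

print(itos)                                   # ['<unk>', '<pad>', 'a', 'movie', 'boring', 'great']
print(numericalize("a great unseen movie"))   # [2, 5, 0, 3]
```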


In [None]:
TEXT = data.Field(lower=True, include_lengths=True, batch_first=True) # If no tokenize argument is passed, the default is simply splitting the string on spaces.
LABEL = data.Field(sequential=False)

NUM_CLS = 2
BATCH_SIZE = 4
MAX_LENGTH = 256 #512
EMB_SIZE = 128
HEADS = 8
DEPTH = 3 # number of Transformer blocks
VOC_SIZE = 50000

LR_RATE = 0.0001
WARMUP = 10000

In [None]:
tbw = SummaryWriter(log_dir='./logs') # Tensorboard logging

train, test = datasets.IMDB.splits(TEXT, LABEL)

downloading aclImdb_v1.tar.gz


aclImdb_v1.tar.gz: 100%|██████████| 84.1M/84.1M [00:14<00:00, 5.93MB/s]


In [None]:
# View 'text' and 'label'
print(vars(train.examples[0]))

{'text': ['first', 'of', 'all,', 'i', 'loved', 'bruce', "broughton's", 'music', 'score,', 'very', 'lyrical,', 'and', 'this', 'alone', 'added', 'to', 'the', "film's", 'charm.', 'the', 'best', 'aspect', 'of', 'the', 'movie', 'were', 'the', 'three', 'animals,', 'superlatively', 'voiced', 'by', 'michael', 'j.fox,', 'sally', 'field', 'and', 'the', 'late', 'don', 'ameche.', 'whereas', 'fox', 'has', 'the', 'funniest', 'lines,', 'ameche', 'plays', 'a', 'rather', 'brooding', 'otherwise', 'engaging', 'character(the', 'voice', 'of', 'reason),', 'and', 'field', 'adds', 'wit', 'into', 'a', 'character', 'that', 'is', 'always', 'seen', 'telling', 'chance', 'off.', 'the', 'humans', "weren't", 'as', 'engaging,', 'and', 'sometimes', 'the', 'film', 'dragged,', 'but', 'that', 'is', 'my', 'only', 'complaint.', 'this', 'is', 'one', 'beautiful-looking', 'film,', 'with', 'beautiful', 'close', 'up', 'shots', 'of', 'canada,', 'i', 'believe.', 'although', 'the', 'film', 'itself', 'is', 'quite', 'long,', 'there',

In [None]:
TEXT.build_vocab(train, max_size=VOC_SIZE - 2) # leave room for the <unk> and <pad> special tokens
LABEL.build_vocab(train)

In [None]:
train_iter, test_iter = data.BucketIterator.splits((train, test), batch_size=BATCH_SIZE, device=device)

In [None]:
# batch size = 4
for batch in train_iter:

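    input = batch.text[0]   # batch.text is (token ids, lengths) because include_lengths=True
    label = batch.label - 1 # index 0 of LABEL's vocab is <unk>, so shift the labels to {0, 1}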

    print(input)
    print(label)
    break

tensor([[  292,    43,    10,    24,     7, 18489,     0,  7672,   590,   552,
          3394,  2064,  2133,  5502,    10,    20,  9396,    72,     0,     2,
           201,   207,    26,    71,     3,  3717,     0,   643,  3500,     7,
           421,   425,     4,    41,    23,   176,     3,  2657,    19,    48,
             7,   162,    19,   411,     2,  1321,   405,    22,  4379,    60,
           132,   883,    29,   364,    49,   467,     0,     4,    39,   214,
           245,     0,   301,   362,    22,    16,     2, 47820,     2,    38,
           437,  2775,  1352,   197,    10,  2336,  3498,     8,  3490,     0,
            17,   467,   275,  7770,     0, 25170,    11,    22,     0,     4,
             3,  1276,  2121,    11,  4285,    17,     3,  2264,     5,  4825,
         12880,    46,     2,  2336,     7,   443,  5745,     4,   164,     0,
           144,  1587, 13608,  1214, 24076,   313,   292,     7, 16132,     2,
           999,   128,    28,     2,    82,  2039,  

In [None]:
print(f'- nr. of training batches {len(train_iter)}')
print(f'- nr. of test batches {len(test_iter)}')

- nr. of training batches 6250
- nr. of test batches 6250


In [None]:
# create the model
model = CTransformer(emb=EMB_SIZE, heads=HEADS, depth=DEPTH, seq_length=MAX_LENGTH, num_tokens=VOC_SIZE, num_classes=NUM_CLS, max_pool=True, dropout=0.2)
model.to(device)

CTransformer(
  (token_embedding): Embedding(50000, 128)
  (pos_embedding): Embedding(256, 128)
  (tblocks): Sequential(
    (0): TransformerBlock(
      (mhattention): MHSelfAttention(
        (tokeys): Linear(in_features=128, out_features=128, bias=False)
        (toqueries): Linear(in_features=128, out_features=128, bias=False)
        (tovalues): Linear(in_features=128, out_features=128, bias=False)
        (unifyheads): Linear(in_features=128, out_features=128, bias=True)
      )
      (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (ff): Sequential(
        (0): Linear(in_features=128, out_features=512, bias=True)
        (1): ReLU()
        (2): Linear(in_features=512, out_features=128, bias=True)
      )
      (do): Dropout(p=0.2, inplace=False)
    )
    (1): TransformerBlock(
      (mhattention): MHSelfAttention(
        (tokeys): Linear(in_features=128, out_features=128, bias=False)
   

In [None]:
opt = torch.optim.Adam(lr=LR_RATE, params=model.parameters())
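# Linear warmup: the learning rate grows from 0 to LR_RATE over the first WARMUP examples,
# i.e. WARMUP / BATCH_SIZE optimizer steps, and then stays constant.
sch = torch.optim.lr_scheduler.LambdaLR(opt, lambda i: min(i / (WARMUP / BATCH_SIZE), 1.0))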
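
`LambdaLR` multiplies the base learning rate by the value returned by the lambda at each step. The sketch below evaluates the same expression in plain Python for a few step counts; `warmup_factor` is just an illustrative name.

```python
WARMUP = 10000     # number of training examples over which the rate ramps up
BATCH_SIZE = 4

def warmup_factor(step):
    # same expression as the lambda passed to LambdaLR
    return min(step / (WARMUP / BATCH_SIZE), 1.0)

factors = {step: warmup_factor(step) for step in (0, 625, 1250, 2500, 5000)}
print(factors)  # {0: 0.0, 625: 0.25, 1250: 0.5, 2500: 1.0, 5000: 1.0}
```

With these settings the full learning rate is reached after 2500 steps, less than half of the first epoch.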

In [None]:
NUM_EPOCHS = 4

# training loop
seen = 0
for e in range(NUM_EPOCHS):

    print(f'\n epoch {e}')
    model.train()

    for batch in tqdm.tqdm(train_iter):

        opt.zero_grad()

        input = batch.text[0]
        label = batch.label - 1

        if input.size(1) > MAX_LENGTH:
            input = input[:, :MAX_LENGTH]


        out = model(input)
        loss = F.nll_loss(out, label)
        # equivalent alternative: F.cross_entropy applied to the raw logits (before log_softmax)

        loss.backward()

        # clip gradients
        # Performs gradient clipping. It is used to mitigate the problem of exploding gradients.
        # - If the total gradient vector has a length > 1, we clip it back down to 1.
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        opt.step()
        sch.step()

        seen += input.size(0)
        tbw.add_scalar('classification/train-loss', float(loss.item()), seen)

    with torch.no_grad():

        model.eval()
        tot, cor= 0.0, 0.0

        for batch in test_iter:

            input = batch.text[0]
            label = batch.label - 1

            if input.size(1) > MAX_LENGTH:
                input = input[:, :MAX_LENGTH]
            out = model(input).argmax(dim=1)

            tot += float(input.size(0))
            cor += float((label == out).sum().item())

        acc = cor / tot
        print(f'-- test accuracy {acc:.3}')
        tbw.add_scalar('classification/test-acc', acc, e) # log the test accuracy (no test loss is computed here)



 epoch 0


100%|██████████| 6250/6250 [01:48<00:00, 57.85it/s]


-- test accuracy 0.577

 epoch 1


 89%|████████▊ | 5545/6250 [01:34<00:12, 58.61it/s]


KeyboardInterrupt: 
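
The training loop pairs `F.log_softmax` (inside the model) with `F.nll_loss`. As a sanity check, this sketch on made-up logits shows that the combination matches `F.cross_entropy` applied directly to the logits:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 2)            # (batch, classes), made-up values
labels = torch.tensor([0, 1, 1, 0])

loss_a = F.nll_loss(F.log_softmax(logits, dim=1), labels)
loss_b = F.cross_entropy(logits, labels)

same_loss = torch.allclose(loss_a, loss_b)
print(same_loss)
```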

In [None]:
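# TODO: inference on a raw sentence (sketch).
# To restore saved weights first: model.load_state_dict(torch.load('path'))

model.eval()
tokens = TEXT.preprocess("I bought that book and I enjoyed the readings")
ids, _ = TEXT.process([tokens], device=device)  # (ids, lengths) because include_lengths=True

with torch.no_grad():
    pred = model(ids[:, :MAX_LENGTH]).argmax(dim=1).item()

print(pred)  # 0 or 1; LABEL.vocab.itos[pred + 1] gives the class name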