<a href="https://colab.research.google.com/github/ambideXtrous9/Transformer-Vanilla-Encoder-Decoder-Greedy-Decoding/blob/main/Vanilla_Transformer_Encoder_Decoder_Greedy_Decoding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Simple PyTorch Transformer Example with Greedy Decoding

Transformers are a game-changing innovation in deep learning.

This architecture has largely superseded RNN variants in NLP tasks, and shows promise of doing the same to CNNs in vision tasks.

However, the PyTorch Transformer docs make it a bit difficult to get started.

- There is no explanation of how to do inference
- The tutorial shows an encoder-only transformer

This notebook provides a simple, self-contained example of a Transformer:

- using both the encoder and decoder parts
- greedy decoding at inference time

We train on a simple synthetic example, and use PyTorch-Lightning for the training loop.

This post was written by https://twitter.com/sergeykarayev for https://twitter.com/full_stack_dl Tooling Tuesdays.

January 12, 2021

In [None]:
!pip install pytorch_lightning

Collecting pytorch_lightning
  Downloading pytorch_lightning-2.1.2-py3-none-any.whl (776 kB)
Collecting torchmetrics>=0.7.0 (from pytorch_lightning)
  Downloading torchmetrics-1.2.1-py3-none-any.whl (806 kB)
Collecting lightning-utilities>=0.8.0 (from pytorch_lightning)
  Downloading lightning_utilities-0.10.0-py3-none-any.whl (24 kB)
Installing collected packages: lightning-utilities, torchmetrics, pytorch_lightning
Successfully installed lightning-utilities-0.10.0 pytorch_lightning-2.1.2 torchmetrics-1.2.1


In [None]:
import math
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.classification import Accuracy


## Data

First, we generate simple input and output data.

Output: random number sequences like [1, 5, 3]

Input: same as output, but with each element repeated twice, e.g. [1, 1, 5, 5, 3, 3]
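
To make the input/output relationship concrete, here is a tiny plain-Python sketch of the same construction for a single sequence. The helper name `make_pair` is hypothetical and exists only for illustration; the notebook itself builds whole tensors at once with `torch.repeat_interleave`.

```python
START, END = 0, 1  # special tokens, matching the convention used in this notebook


def make_pair(target):
    """Return (x, y): x repeats each element of target twice; both get start/end tokens."""
    repeated = [tok for tok in target for _ in range(2)]
    x = [START] + repeated + [END]
    y = [START] + list(target) + [END]
    return x, y


x_demo, y_demo = make_pair([5, 9, 3])
print(x_demo)  # [0, 5, 5, 9, 9, 3, 3, 1]
print(y_demo)  # [0, 5, 9, 3, 1]
```

The content tokens start at 2 so that they never collide with the start and end tokens.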

In [None]:
N = 10000
S = 32  # target sequence length, including start/end tokens; the input has 2 * (S - 2) + 2 = 62 tokens
C = 128  # number of "classes", including 0, the "start token", and 1, the "end token"

Y = (torch.rand((N * 10, S - 2)) * (C - 2)).long() + 2  # Random ints in [2, C - 1]; 0 and 1 are reserved

# Make sure we only have unique rows
Y = torch.tensor(np.unique(Y, axis=0)[:N])
X = torch.repeat_interleave(Y, 2, dim=1)

# Add special 0 "start" and 1 "end" tokens to beginning and end
Y = torch.cat([torch.zeros((N, 1)), Y, torch.ones((N, 1))], dim=1).long()
X = torch.cat([torch.zeros((N, 1)), X, torch.ones((N, 1))], dim=1).long()

# Look at the data
print(X, X.shape)
print(Y, Y.shape)
print(Y.min(), Y.max())

tensor([[  0,   2,   2,  ..., 102, 102,   1],
        [  0,   2,   2,  ...,  44,  44,   1],
        [  0,   2,   2,  ...,  16,  16,   1],
        ...,
        [  0,  14,  14,  ...,  59,  59,   1],
        [  0,  14,  14,  ...,  34,  34,   1],
        [  0,  14,  14,  ...,  18,  18,   1]]) torch.Size([10000, 62])
tensor([[  0,   2,   2,  ...,  93, 102,   1],
        [  0,   2,   2,  ...,  26,  44,   1],
        [  0,   2,   2,  ...,  60,  16,   1],
        ...,
        [  0,  14,  45,  ...,  96,  59,   1],
        [  0,  14,  45,  ...,  40,  34,   1],
        [  0,  14,  45,  ..., 114,  18,   1]]) torch.Size([10000, 32])
tensor(0) tensor(127)


In [None]:
# Wrap data in the simplest possible way to enable PyTorch data fetching
# https://pytorch.org/docs/stable/data.html

BATCH_SIZE = 128
TRAIN_FRAC = 0.8

dataset = list(zip(X, Y))  # This fulfills the torch.utils.data.Dataset interface

# Split into train and val
num_train = int(N * TRAIN_FRAC)
num_val = N - num_train
data_train, data_val = torch.utils.data.random_split(dataset, (num_train, num_val))

dataloader_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE)
dataloader_val = torch.utils.data.DataLoader(data_val, batch_size=BATCH_SIZE)

# Sample batch
x, y = next(iter(dataloader_train))
x, y

(tensor([[ 0, 12, 12,  ..., 47, 47,  1],
         [ 0, 12, 12,  ..., 23, 23,  1],
         [ 0,  6,  6,  ..., 65, 65,  1],
         ...,
         [ 0, 11, 11,  ..., 35, 35,  1],
         [ 0, 10, 10,  ...,  7,  7,  1],
         [ 0, 14, 14,  ..., 50, 50,  1]]),
 tensor([[  0,  12, 115,  ...,  18,  47,   1],
         [  0,  12,  42,  ...,  96,  23,   1],
         [  0,   6,  17,  ...,  96,  65,   1],
         ...,
         [  0,  11,  12,  ...,   4,  35,   1],
         [  0,  10,  98,  ...,  83,   7,   1],
         [  0,  14,  38,  ...,  33,  50,   1]]))

## Model

![](https://media.arxiv-vanity.com/render-output/3715543/Figures/ModalNet-21.png)
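
The decoder half of this architecture may only attend to earlier output positions. As a rough illustration, the plain-Python sketch below builds the same values that `generate_square_subsequent_mask` (defined in the next cell) produces: 0.0 where attention is allowed and -inf where it is blocked. The name `causal_mask` is made up for this sketch, and it uses nested lists rather than tensors.

```python
def causal_mask(size):
    """Row i holds 0.0 for columns j <= i (visible) and -inf for j > i (future, hidden)."""
    neg_inf = float("-inf")
    return [[0.0 if j <= i else neg_inf for j in range(size)] for i in range(size)]


mask_demo = causal_mask(3)
for row in mask_demo:
    print(row)
# [0.0, -inf, -inf]
# [0.0, 0.0, -inf]
# [0.0, 0.0, 0.0]
```

Because the mask is added to the attention scores before the softmax, a -inf entry gives that position zero attention weight.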

In [None]:
class PositionalEncoding(nn.Module):
    """
    Classic Attention-is-all-you-need positional encoding.
    From PyTorch docs.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


def generate_square_subsequent_mask(size: int):
    """Generate a triangular (size, size) mask. From PyTorch docs."""
    mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


class Transformer(nn.Module):
    """
    Classic Transformer that both encodes and decodes.

    Prediction-time inference is done greedily.

    NOTE: start token is hard-coded to be 0, end token to be 1. If changing, update predict() accordingly.
    """

    def __init__(self, num_classes: int, max_output_length: int, dim: int = 128):
        super().__init__()

        # Parameters
        self.dim = dim
        self.max_output_length = max_output_length
        nhead = 4
        num_layers = 4
        dim_feedforward = dim

        # Encoder part
        self.embedding = nn.Embedding(num_classes, dim)
        self.pos_encoder = PositionalEncoding(d_model=self.dim)
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(d_model=self.dim, nhead=nhead, dim_feedforward=dim_feedforward),
            num_layers=num_layers
        )

        # Decoder part
        self.y_mask = generate_square_subsequent_mask(self.max_output_length)
        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layer=nn.TransformerDecoderLayer(d_model=self.dim, nhead=nhead, dim_feedforward=dim_feedforward),
            num_layers=num_layers
        )
        self.fc = nn.Linear(self.dim, num_classes)

        # It is empirically important to initialize weights properly
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        self.fc.weight.data.uniform_(-initrange, initrange)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Input
            x: (B, Sx) with elements in [0, C) where C is num_classes
            y: (B, Sy) with elements in [0, C) where C is num_classes
        Output
            (B, C, Sy) logits
        """
        encoded_x = self.encode(x)  # (Sx, B, E)
        output = self.decode(y, encoded_x)  # (Sy, B, C)
        return output.permute(1, 2, 0)  # (B, C, Sy)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Input
            x: (B, Sx) with elements in [0, C) where C is num_classes
        Output
            (Sx, B, E) embedding
        """
        x = x.permute(1, 0)  # (Sx, B)
        x = self.embedding(x) * math.sqrt(self.dim)  # (Sx, B, E)
        x = self.pos_encoder(x)  # (Sx, B, E)
        x = self.transformer_encoder(x)  # (Sx, B, E)
        return x

    def decode(self, y: torch.Tensor, encoded_x: torch.Tensor) -> torch.Tensor:
        """
        Input
            encoded_x: (Sx, B, E)
            y: (B, Sy) with elements in [0, C) where C is num_classes
        Output
            (Sy, B, C) logits
        """
        y = y.permute(1, 0)  # (Sy, B)
        y = self.embedding(y) * math.sqrt(self.dim)  # (Sy, B, E)
        y = self.pos_encoder(y)  # (Sy, B, E)
        Sy = y.shape[0]
        y_mask = self.y_mask[:Sy, :Sy].type_as(encoded_x)  # (Sy, Sy)
        output = self.transformer_decoder(y, encoded_x, y_mask)  # (Sy, B, E)
        output = self.fc(output)  # (Sy, B, C)
        return output

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """
        Method to use at inference time. Predict y from x one token at a time. This method is greedy
        decoding. Beam search can be used instead for a potential accuracy boost.

        Input
            x: (B, Sx) with elements in [0, C) where C is num_classes
        Output
            (B, max_output_length) predicted token ids
        """
        encoded_x = self.encode(x)

        output_tokens = (torch.ones((x.shape[0], self.max_output_length))).type_as(x).long() # (B, max_length)
        output_tokens[:, 0] = 0  # Set start token
        for Sy in range(1, self.max_output_length):
            y = output_tokens[:, :Sy]  # (B, Sy)
            output = self.decode(y, encoded_x)  # (Sy, B, C)
            output = torch.argmax(output, dim=-1)  # (Sy, B)
            output_tokens[:, Sy] = output[-1]  # Prediction at the last position, shape (B,)
        return output_tokens


model = Transformer(num_classes=C, max_output_length=y.shape[1])
logits = model(x, y[:, :-1])
print(x.shape, y.shape, logits.shape)
print(x[0:1])
print(model.predict(x[0:1]))

torch.Size([128, 62]) torch.Size([128, 32]) torch.Size([128, 128, 31])
tensor([[  0,   2,   2,  93,  93,  77,  77,  18,  18, 105, 105, 124, 124,  34,
          34,  68,  68,  23,  23,  18,  18, 117, 117,  97,  97,  92,  92, 117,
         117, 119, 119,  13,  13,  84,  84, 125, 125,  81,  81,   7,   7,   8,
           8,  45,  45,  25,  25, 103, 103,  46,  46, 103, 103,  13,  13, 100,
         100, 105, 105,  60,  60,   1]])
tensor([[  0,  49,  23,  23,  23, 108,  23,  23,  23,  23,  23,  23,  23,  23,
          23,  23,  23,  23,  23,  23,  23,  35,  23,  23,  23,  23,  35,  23,
          35,  23,  35,  23]])
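
The loop inside `predict()` can be hard to follow among the tensor reshapes, so here is a framework-free sketch of the same greedy procedure: start from the start token, score every candidate next token, append the argmax, and repeat until the fixed output length is reached. `toy_scores` is a hypothetical stand-in for the trained decoder (it simply "knows" the right answer), and `greedy_decode` is an illustrative name, not part of the notebook's model.

```python
START, END = 0, 1


def toy_scores(source, prefix):
    """Hypothetical stand-in for the decoder: one score per vocabulary token."""
    position = len(prefix) - 1  # prefix[0] is the start token
    wanted = source[position] if position < len(source) else END
    vocab_size = max(source + [END]) + 1
    return [1.0 if tok == wanted else 0.0 for tok in range(vocab_size)]


def greedy_decode(source, max_len):
    """Build the output one token at a time, always taking the highest-scoring token."""
    tokens = [START]
    while len(tokens) < max_len:
        scores = toy_scores(source, tokens)
        best = max(range(len(scores)), key=scores.__getitem__)  # argmax
        tokens.append(best)
    return tokens


print(greedy_decode([7, 4, 9], max_len=5))  # [0, 7, 4, 9, 1]
```

Like `predict()`, this sketch always runs for a fixed number of steps and does not stop early at the end token.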


In [None]:
from sklearn.metrics import accuracy_score

In [None]:
class LitModel(pl.LightningModule):
    """Simple PyTorch-Lightning model to train our Transformer."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.loss = nn.CrossEntropyLoss()
        self.val_acc = Accuracy(task="multiclass", num_classes=C)  # C token classes (not the sequence length)

    def training_step(self, batch, batch_ind):
        x, y = batch
        # Teacher forcing: the decoder input is the target without its last token,
        # while the ground truth is the target from the second token onward.
        logits = self.model(x, y[:, :-1])
        loss = self.loss(logits, y[:, 1:])
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_ind):
        x, y = batch
        logits = self.model(x, y[:, :-1])
        loss = self.loss(logits, y[:, 1:])
        pred = self.model.predict(x)
        correct_predictions = (y == pred).float()

        # Calculate accuracy for each row
        row_accuracies = torch.mean(correct_predictions, dim=1)

        # Overall accuracy (average over all rows)
        accuracy = torch.mean(row_accuracies).item()
        # accuracy = accuracy_score(y.cpu().numpy(), pred.cpu().numpy())
        # self.val_acc(pred, y)
        self.log_dict({"val_acc" : accuracy,
                       "val_loss" : loss,
                       },on_step=False, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())
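
The slicing in `training_step` (`y[:, :-1]` as decoder input, `y[:, 1:]` as ground truth) is easier to see on a single short sequence. The following plain-Python sketch uses made-up token values and only illustrates the alignment; it does not involve the model.

```python
y_seq = [0, 7, 4, 9, 1]  # start token, three content tokens, end token

decoder_input = y_seq[:-1]  # what the decoder sees:      [0, 7, 4, 9]
target = y_seq[1:]          # what it is asked to output: [7, 4, 9, 1]

for step, (seen, want) in enumerate(zip(decoder_input, target)):
    print(f"step {step}: latest input token {seen} -> should predict {want}")
```

At every position the decoder is trained to predict the token that follows the ones it has seen so far, which is the same situation it faces during greedy decoding.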



In [None]:

model = Transformer(num_classes=C, max_output_length=y.shape[1])
lit_model = LitModel(model)
early_stop_callback = pl.callbacks.EarlyStopping(monitor='val_loss')




In [None]:
trainer = pl.Trainer(max_epochs=5,
                     accelerator="gpu",
                     devices=-1,
                     callbacks=[early_stop_callback])


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [None]:
trainer.fit(lit_model, dataloader_train, dataloader_val)

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type               | Params
-----------------------------------------------
0 | model   | Transformer        | 1.1 M 
1 | loss    | CrossEntropyLoss   | 0     
2 | val_acc | MulticlassAccuracy | 0     
-----------------------------------------------
1.1 M     Trainable params
0         Non-trainable params
1.1 M     Total params
4.379     Total estimated model params size (MB)



/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.



INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.


In [None]:
# We can see that the decoding works correctly

x, y = next(iter(dataloader_val))
print('Input:', x[:1])
pred = lit_model.model.predict(x[:1])
print('Truth/Pred:')
print(torch.cat((y[:1], pred)))

Input: tensor([[  0,   2,   2,  93,  93,  77,  77,  18,  18, 105, 105, 124, 124,  34,
          34,  68,  68,  23,  23,  18,  18, 117, 117,  97,  97,  92,  92, 117,
         117, 119, 119,  13,  13,  84,  84, 125, 125,  81,  81,   7,   7,   8,
           8,  45,  45,  25,  25, 103, 103,  46,  46, 103, 103,  13,  13, 100,
         100, 105, 105,  60,  60,   1]])
Truth/Pred:
tensor([[  0,   2,  93,  77,  18, 105, 124,  34,  68,  23,  18, 117,  97,  92,
         117, 119,  13,  84, 125,  81,   7,   8,  45,  25, 103,  46, 103,  13,
         100, 105,  60,   1],
        [  0,   2,  93,  77,  18, 105, 124,  34,  68,  23,  18, 117,  97,  92,
         117, 119,  13,  84, 125,  81,   7,   8,  45,  25, 103,  46, 103,  13,
         100, 105,  60,   1]])
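
In this dataset every target has the same length, so `predict()` can simply fill all `max_output_length` positions. With variable-length targets one would typically cut each prediction at its first end token. A minimal sketch of that post-processing step on plain lists, assuming the end token is 1 as in this notebook (`truncate_at_end` is a hypothetical helper, not part of the model above):

```python
END = 1


def truncate_at_end(tokens, end_token=END):
    """Keep tokens up to and including the first end token; drop anything after it."""
    out = []
    for tok in tokens:
        out.append(tok)
        if tok == end_token:
            break
    return out


print(truncate_at_end([0, 7, 4, 1, 23, 23]))  # [0, 7, 4, 1]
print(truncate_at_end([0, 7, 4]))             # [0, 7, 4] (no end token found)
```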


That's it for now! Hope this example was helpful.

Follow us at https://twitter.com/full_stack_dl and https://twitter.com/sergeykarayev for more Tooling Tuesdays posts :)

Check out the official PyTorch docs, a helpful blog post from ScaleAI, and the clearest explanation of the Transformer architecture:

- https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html
- https://pgresia.medium.com/making-pytorch-transformer-twice-as-fast-on-sequence-generation-2a8a7f1e7389
- http://peterbloem.nl/blog/transformers