# Vision transformer

## Preliminaries

### Libraries and imports

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

### Global variables

In [2]:
# MNIST images are 28x28
IMAGE_SIZE = 28

# Divide image into (28/7)x(28/7) patches
PATCH_SIZE = 7
# Number of patches along one side of the image (integer division)
NUM_SPLITS = IMAGE_SIZE // PATCH_SIZE
# Total number of patches per image, i.e. the sequence length before the
# classification token is prepended
NUM_PATCHES = NUM_SPLITS ** 2

# Number of samples per minibatch, used by both data loaders
BATCH_SIZE = 100
# Size of each patch/token embedding vector
EMBEDDING_DIM = 8
# Number of attention heads in each transformer encoder layer
NUM_HEADS = 2
# Size of the model output (one score per class)
NUM_CLASSES = 10
# Number of stacked transformer encoder layers
NUM_TRANSFORMER_LAYERS = 4
# Width of the hidden layer of each encoder's feed-forward network
HIDDEN_DIM = 16
# Number of passes over the training set
EPOCHS = 3
# Use the GPU when CUDA is available, otherwise fall back to the CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## The `MNIST` dataset

See [here](https://en.wikipedia.org/wiki/MNIST_database) for details.

In [4]:
# Convert each image to a tensor, then normalize its single channel with mean
# 0.1307 and standard deviation 0.3081 (presumably the MNIST training-set
# statistics -- TODO confirm)
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# Training and test splits, stored under `./.data`; `download=True` asks
# torchvision to fetch the files if needed
train_dataset = datasets.MNIST(
    root="./.data", train=True, transform=transform, download=True
)
test_dataset = datasets.MNIST(
    root="./.data", train=False, transform=transform, download=True
)

# Define data loaders with `BATCH_SIZE`. Only the training loader shuffles;
# the test loader keeps the dataset order
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

## Patch Embedding Layer

The first module to implement transforms a tensor of size
`BATCH_SIZE` \* 1 \* `IMAGE_SIZE` \* `IMAGE_SIZE` into a tensor of size
`BATCH_SIZE` \* `NUM_PATCHES` \* `EMBEDDING_DIM`. This can be done with an
`nn.Conv2d` module whose kernel size and stride are both equal to the patch
size, followed by flattening the two spatial dimensions and moving the
resulting patch dimension to second place.

In [6]:
class PatchEmbedding(nn.Module):
    """Turn a batch of images into a sequence of embedded patches.

    A convolution whose kernel size and stride both equal ``patch_size`` maps
    every non-overlapping patch of the image to one vector of size
    ``embedding_dim``.

    Args:
        in_channels: Number of channels of the input images.
        patch_size: Side length of a square patch, in pixels.
        embedding_dim: Size of the vector produced for each patch.
    """

    def __init__(self, in_channels=1, patch_size=7, embedding_dim=64):
        super().__init__()
        self.patch_size = patch_size
        self.embedding_dim = embedding_dim
        # One convolution step per patch: kernel and stride are both the
        # patch size, so patches do not overlap
        self.projection = nn.Conv2d(
            in_channels,
            embedding_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Embed the patches of ``x``.

        Args:
            x: Tensor of size batch * ``in_channels`` * height * width.

        Returns:
            Tensor of size batch * number of patches * ``embedding_dim``.
        """
        # batch * embedding_dim * splits * splits
        patch_grid = self.projection(x)

        # Merge the two spatial axes (batch * embedding_dim * num_patches),
        # then swap the last two axes so that patches come second
        return patch_grid.flatten(start_dim=2).transpose(1, 2)

## Transformer encoder

In [10]:
class TransformerEncoder(nn.Module):
    """One transformer encoder layer.

    The layer applies multi-head self-attention and a position-wise
    feed-forward network. Each of the two is wrapped in a skip-connection
    and followed by its own layer normalization.

    Args:
        embedding_dim: Size of each token embedding (input and output).
        num_heads: Number of attention heads.
        hidden_dim: Width of the hidden layer of the feed-forward network.
    """

    def __init__(self, embedding_dim, num_heads, hidden_dim):
        super().__init__()
        # `batch_first=True` because the inputs are batch * tokens * features
        self.attention = nn.MultiheadAttention(
            embed_dim=embedding_dim,
            num_heads=num_heads,
            batch_first=True,
        )

        # Position-wise feed-forward network: linear -> GELU -> linear
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embedding_dim),
        )

        # One layer normalization after each skip-connection
        self.layernorm1 = nn.LayerNorm(embedding_dim)
        self.layernorm2 = nn.LayerNorm(embedding_dim)

    def forward(self, x):
        """Apply the encoder layer to ``x``.

        Args:
            x: Tensor of size batch * tokens * ``embedding_dim``.

        Returns:
            Tensor of the same size as ``x``.
        """
        # Self-attention: `x` is query, key and value; the attention weights
        # (second returned element) are not used
        attended = self.attention(x, x, x)[0]
        after_attention = self.layernorm1(x + attended)

        # Feed-forward network with its own skip-connection and normalization
        return self.layernorm2(after_attention + self.mlp(after_attention))

## Vision Transformer

In [12]:
class VisionTransformer(nn.Module):
    """Vision transformer classifier for single-channel square images.

    The image is split into patches that are embedded by `PatchEmbedding`. A
    learnable classification token is prepended, a learnable positional
    encoding is added, and the sequence goes through a stack of
    `TransformerEncoder` layers. The output of the classification token is
    finally mapped to `num_classes` scores.

    Args:
        patch_size: Side length of a square patch, in pixels.
        embedding_dim: Size of each token embedding.
        num_heads: Number of attention heads per encoder layer.
        num_classes: Number of output scores.
        num_transformer_layers: Number of stacked encoder layers.
        hidden_dim: Hidden width of each encoder's feed-forward network.
        image_size: Side length of the input images, in pixels. Together with
            `patch_size` it fixes the length of the positional encoding.
            Defaults to 28 (the MNIST image size).
    """

    def __init__(
            self,
            patch_size,
            embedding_dim,
            num_heads,
            num_classes,
            num_transformer_layers,
            hidden_dim,
            image_size=28,
    ):
        super().__init__()

        # Number of patches per image, derived from the arguments so that the
        # positional encoding always matches `patch_size`
        num_patches = (image_size // patch_size) ** 2

        # Define a `PatchEmbedding` module
        self.patch_embedding = PatchEmbedding(in_channels=1, patch_size=patch_size, embedding_dim=embedding_dim)

        # Use `nn.Parameter` to define an additional token embedding that will
        # be used to predict the class
        self.cls_token = nn.Parameter(torch.zeros((1, 1, embedding_dim)))

        # Use `nn.Parameter` to define a learnable positional encoding with one
        # position per patch plus one for the classification token
        self.position_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embedding_dim))

        # Use `nn.init.xavier_uniform_` to initialize the positional embedding
        nn.init.xavier_uniform_(self.position_embedding)

        self.encoder_layers = nn.Sequential(
            *[
                TransformerEncoder(embedding_dim, num_heads, hidden_dim)
                for _ in range(num_transformer_layers)
            ]
        )

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embedding_dim), nn.Linear(embedding_dim, num_classes)
        )

    def forward(self, x):
        """Compute class scores for a batch of images.

        Args:
            x: Tensor of size batch * 1 * `image_size` * `image_size`.

        Returns:
            Tensor of size batch * `num_classes`.
        """
        # Transform images into embedded patches. It gives a tensor of size
        # batch * num_patches * embedding_dim
        x = self.patch_embedding(x)

        # We need to add the embedded classification token at the beginning of
        # each sequence in the minibatch. Use `expand` to duplicate it along the
        # batch size dimension
        batch_size = x.shape[0]
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)

        # Next use `torch.cat` to concatenate `cls_tokens` and `x` to have a
        # tensor of size batch * (num_patches + 1) * embedding_dim
        x = torch.cat((cls_tokens, x), dim=1)

        # Add the positional encoding (broadcast along the batch dimension).
        # Out-of-place so that no tensor is modified in place
        x = x + self.position_embedding

        # Apply the stacked transformer modules
        y = self.encoder_layers(x)

        # Select the classification token for each sample in the minibatch.
        # Indexing with `0` drops the token dimension, so `cls_output` is of
        # size batch * embedding_dim
        cls_output = y[:, 0, :]

        # Use `self.mlp_head` to adapt the output size to `num_classes`
        out = self.mlp_head(cls_output)

        return out

## Initialize model, loss and optimizer

In [18]:
# Define the `VisionTransformer` model and move it to `DEVICE`: every batch is
# moved to `DEVICE` before the forward pass, so the parameters must live on
# the same device
model = VisionTransformer(
    PATCH_SIZE,
    EMBEDDING_DIM,
    NUM_HEADS,
    NUM_CLASSES,
    NUM_TRANSFORMER_LAYERS,
    HIDDEN_DIM,
).to(DEVICE)

# Use cross-entropy loss and AdamW optimizer with a learning rate of 5e-3.
# The optimizer is created after the model has been moved to `DEVICE`
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=5e-3)

## Validation loss calculation

In [19]:
def validate_model(model, val_loader, criterion):
    """Evaluate `model` on every batch of `val_loader`.

    The model is switched to evaluation mode and gradients are disabled. Each
    batch is moved to `DEVICE` before the forward pass.

    Args:
        model: Module called on the image batches.
        val_loader: Iterable of `(images, labels)` batches that supports `len`.
        criterion: Callable `(outputs, labels)` returning a scalar loss tensor.

    Returns:
        A tuple `(loss, accuracy)`: the sum of the batch losses divided by the
        number of batches, and the percentage of samples whose highest-scoring
        class equals the label.
    """
    model.eval()

    loss_sum = 0
    num_correct = 0
    num_samples = 0

    with torch.no_grad():
        for batch_images, batch_labels in val_loader:
            batch_images = batch_images.to(DEVICE)
            batch_labels = batch_labels.to(DEVICE)

            scores = model(batch_images)
            loss_sum += criterion(scores, batch_labels).item()

            # Index of the largest score along the class dimension
            predictions = scores.max(1).indices
            num_correct += (predictions == batch_labels).sum().item()
            num_samples += batch_labels.size(0)

    mean_loss = loss_sum / len(val_loader)
    accuracy = 100 * num_correct / num_samples
    return mean_loss, accuracy

## Training with Validation

In [20]:
# Train for `EPOCHS` epochs; after each one, evaluate on `test_loader` and
# print the losses and the accuracy
for epoch in range(1, EPOCHS + 1):
    # Training mode (the validation step below switches the model to eval mode)
    model.train()
    # Sum of the batch losses over the epoch
    total_loss = 0
    for images, labels in train_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        # Clear the gradients accumulated by the previous step
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Calculate validation loss and accuracy
    val_loss, val_accuracy = validate_model(model, test_loader, criterion)

    # Report the average training loss per batch and the validation metrics
    print(f"Epoch {epoch}/{EPOCHS}")
    print(f"Train Loss: {total_loss/len(train_loader):.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%")

Epoch 1/3
Train Loss: 0.9254
Val Loss: 0.4883, Val Accuracy: 84.69%
Epoch 2/3
Train Loss: 0.3961
Val Loss: 0.3304, Val Accuracy: 90.47%
Epoch 3/3
Train Loss: 0.3058
Val Loss: 0.2583, Val Accuracy: 92.40%
