# Training a Vision-Transformer

In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.optim import Adam
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import numpy as np

In [2]:
class PatchEmbedding(nn.Module):
    def __init__(self, d_model, img_size, patch_size, num_channels):
        super().__init__()

        self.d_model = d_model
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_channels = num_channels

        self.linear_project = nn.Conv2d(self.num_channels, self.d_model, kernel_size=self.patch_size, stride=self.patch_size)

    def forward(self, x):
        x = self.linear_project(x)
        x = x.flatten(2)
        x = x.transpose(1, 2)
        return x

In [3]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_sequence_len):
        super().__init__()
        
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))

        position_encoding = torch.zeros(max_sequence_len, d_model)

        for pos in range(max_sequence_len):
            for i in range(0, d_model, 2):
                if i % 2 == 0:
                    position_encoding[pos, i] = np.sin(pos / (10000 ** (i/d_model)))
                else:
                    position_encoding[pos, i + 1] = np.cos(pos / (10000 ** ((i - 1)/d_model)))

        self.register_buffer('position_encoding', position_encoding.unsqueeze(0))

    def forward(self, x):
        tokens_batch = self.cls_token.expand(x.size()[0], -1, -1)
        x = torch.cat((tokens_batch, x), dim=1)
        x = x + self.position_encoding
        return x

In [4]:
class AttentionHead(nn.Module):
    def __init__(self, d_model, head_size):
        super().__init__()
        self.head_size = head_size

        self.query = nn.Linear(d_model, head_size)
        self.key = nn.Linear(d_model, head_size)
        self.value = nn.Linear(d_model, head_size)

    def forward(self, x):
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        attention = Q @ K.transpose(-2, -1)
        attention = attention / (self.head_size ** 0.5)
        attention = torch.softmax(attention, dim=-1)
        attention = attention @ V
        return attention

In [5]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        
        self.head_size = d_model // num_heads

        self.W_o = nn.Linear(d_model, d_model)

        self.heads = nn.ModuleList([AttentionHead(d_model, self.head_size) for _ in range(num_heads)])

    def forward(self, x):
        heads_output = torch.cat([head(x) for head in self.heads], dim=-1)
        output = self.W_o(heads_output)
        return output

In [6]:
class TransformerEncoder(nn.Module):
    def __init__(self, d_model, num_heads, r_mlp=4):
        super().__init__()

        self.d_model = d_model
        self.num_heads = num_heads

        self.ln1 = nn.LayerNorm(d_model)
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, r_mlp * d_model),
            nn.GELU(),
            nn.Linear(r_mlp * d_model, d_model)
        )

    def forward(self, x):
        x = x + self.mha(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

        

In [7]:
class VisionTransformer(nn.Module):
    def __init__(self, d_model, n_classes, img_size, patch_size, num_channels, num_heads, num_layers):
        super().__init__()

        assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, "Image dimensions must be divisible by the patch size."
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads."

        self.d_model = d_model
        self.n_classes = n_classes
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_heads = num_heads
        
        self.num_patches = (self.img_size[0] * self.img_size[1]) // (self.patch_size[0] * self.patch_size[1])
        self.max_sequence_len = self.num_patches + 1

        self.patch_embedding = PatchEmbedding(d_model, img_size, patch_size, num_channels)
        self.positional_encoding = PositionalEncoding(d_model, self.max_sequence_len)
        self.transformer_encoder = nn.Sequential(*[TransformerEncoder(d_model, num_heads) for _ in range(num_layers)])
        self.classifier = nn.Sequential(
            nn.Linear(d_model, n_classes),
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        seq = self.patch_embedding(x)
        seq = self.positional_encoding(seq)
        seq = self.transformer_encoder(seq)
        seq = self.classifier(seq[:, 0])
        return seq

In [12]:
d_model = 9
n_classes = 10
img_size = (32,32)
patch_size = (16,16)
n_channels = 1
n_heads = 3
n_layers = 3
batch_size = 128
epochs = 20
alpha = 0.005

In [13]:
transform = T.Compose([
  T.Resize(img_size),
  T.ToTensor()
])

train_set = MNIST(
  root="./../datasets", train=True, download=True, transform=transform
)
test_set = MNIST(
  root="./../datasets", train=False, download=True, transform=transform
)

train_loader = DataLoader(train_set, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(test_set, shuffle=False, batch_size=batch_size)

In [14]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")

transformer = VisionTransformer(d_model, n_classes, img_size, patch_size, n_channels, n_heads, n_layers).to(device)

optimizer = Adam(transformer.parameters(), lr=alpha)
criterion = nn.CrossEntropyLoss()

for epoch in range(epochs):

  training_loss = 0.0
  for i, data in enumerate(train_loader, 0):
    inputs, labels = data
    inputs, labels = inputs.to(device), labels.to(device)

    optimizer.zero_grad()

    outputs = transformer(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    training_loss += loss.item()

  print(f'Epoch {epoch + 1}/{epochs} loss: {training_loss  / len(train_loader) :.3f}')

Using device:  cpu 
Epoch 1/20 loss: 1.751
Epoch 2/20 loss: 1.651
Epoch 3/20 loss: 1.637
Epoch 4/20 loss: 1.617
Epoch 5/20 loss: 1.550
Epoch 6/20 loss: 1.539
Epoch 7/20 loss: 1.536
Epoch 8/20 loss: 1.533
Epoch 9/20 loss: 1.534
Epoch 10/20 loss: 1.530
Epoch 11/20 loss: 1.528
Epoch 12/20 loss: 1.527
Epoch 13/20 loss: 1.525
Epoch 14/20 loss: 1.523
Epoch 15/20 loss: 1.522
Epoch 16/20 loss: 1.521
Epoch 17/20 loss: 1.522
Epoch 18/20 loss: 1.519
Epoch 19/20 loss: 1.520
Epoch 20/20 loss: 1.518


In [11]:
correct = 0
total = 0

with torch.no_grad():
  for data in test_loader:
    images, labels = data
    images, labels = images.to(device), labels.to(device)

    outputs = transformer(images)

    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()
  print(f'\nModel Accuracy: {100 * correct // total} %')


Model Accuracy: 83 %
