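**References:**

1. Dosovitskiy et al., *An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale* (the original Vision Transformer paper): https://arxiv.org/pdf/2010.11929
2. Brian Pulfer, *Vision Transformers from Scratch (PyTorch): A Step-by-Step Guide* (Medium tutorial this implementation follows): https://medium.com/@brianpulfer/vision-transformers-from-scratch-pytorch-a-step-by-step-guide-96c3313c2e0c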

In [1]:
from vision_transformer import *

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader

from torchvision.transforms import ToTensor
from torchvision.datasets.mnist import MNIST

In [2]:
# Loading data
# MNIST is downloaded on first use into DATA_ROOT. ToTensor converts each
# 28x28 grayscale PIL image into a float tensor of shape (1, 28, 28) scaled
# to [0, 1], which matches the (1, 28, 28) input shape given to MyViT.
DATA_ROOT = "./../datasets"
BATCH_SIZE = 128

transform = ToTensor()

train_set = MNIST(root=DATA_ROOT, train=True, download=True, transform=transform)
test_set = MNIST(root=DATA_ROOT, train=False, download=True, transform=transform)

# Only the training data is shuffled; test order does not affect the metrics.
train_loader = DataLoader(train_set, shuffle=True, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_set, shuffle=False, batch_size=BATCH_SIZE)

In [3]:
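# NOTE: this cell is a superseded, commented-out draft of the three cells
# below (model setup, training, testing), kept only for reference. It used
# N_EPOCHS = 5, which is the configuration that produced the training/test
# outputs currently saved in this notebook ("Epoch k/5").
# # Defining model and training options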
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")
# model = MyViT((1, 28, 28), n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10).to(device)
# N_EPOCHS = 5
# LR = 0.005

# # Training loop
# optimizer = Adam(model.parameters(), lr=LR)
# criterion = CrossEntropyLoss()
# for epoch in trange(N_EPOCHS, desc="Training"):
#     train_loss = 0.0
#     for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False):
#         x, y = batch
#         x, y = x.to(device), y.to(device)
#         y_hat = model(x)
#         loss = criterion(y_hat, y)

#         train_loss += loss.detach().cpu().item() / len(train_loader)

#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#     print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.2f}")

# # Test loop
# with torch.no_grad():
#     correct, total = 0, 0
#     test_loss = 0.0
#     for batch in tqdm(test_loader, desc="Testing"):
#         x, y = batch
#         x, y = x.to(device), y.to(device)
#         y_hat = model(x)
#         loss = criterion(y_hat, y)
#         test_loss += loss.detach().cpu().item() / len(test_loader)

#         correct += torch.sum(torch.argmax(y_hat, dim=1) == y).detach().cpu().item()
#         total += len(x)
#     print(f"Test loss: {test_loss:.2f}")
#     print(f"Test accuracy: {correct / total * 100:.2f}%")

In [4]:
# Defining model and training options
SEED = 0
# Seed torch's global RNG before building the model. This makes the weight
# initialisation reproducible. It also fixes the shuffle order of
# train_loader: with no explicit generator, its sampler draws from torch's
# global RNG each time the loader is iterated in the training cell.
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(
    "Using device: ",
    device,
    f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "",
)

model = MyViT(
    (1, 28, 28), n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10
).to(device)

# NOTE(review): the outputs saved under the training cell show "Epoch k/5",
# i.e. they came from an N_EPOCHS = 5 run, not from the value below. Re-run
# the notebook to refresh them. At the recorded ~6 min/epoch, 25 epochs
# takes roughly 2.5 hours.
N_EPOCHS = 25
LR = 0.005


Using device:  cuda (NVIDIA GeForce RTX 4060 Laptop GPU)


In [5]:
# Training loop
# `tqdm` / `trange` are not imported explicitly in this notebook; presumably
# they are re-exported by `from vision_transformer import *` — confirm.
#
# NOTE(review): the recorded losses level off around 1.7 even though test
# accuracy reaches ~80%. CrossEntropyLoss expects raw logits. If MyViT ends
# with a softmax (as the model in the referenced tutorial does), the loss is
# computed on doubly-softmaxed outputs and cannot drop much below ~1.46 for
# 10 classes. Confirm that MyViT returns logits.
optimizer = Adam(model.parameters(), lr=LR)
criterion = CrossEntropyLoss()

# Make sure training-only behaviour (dropout, etc.) is active, e.g. if this
# cell is re-run after the test cell switched the model to eval mode.
model.train()
for epoch in trange(N_EPOCHS, desc="Training"):
    loss_sum, n_seen = 0.0, 0
    for x, y in tqdm(
        train_loader, desc=f"Epoch {epoch + 1} in training", leave=False
    ):
        x, y = x.to(device), y.to(device)
        y_hat = model(x)
        loss = criterion(y_hat, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # criterion returns the mean over the batch. Weight it by batch size
        # so the smaller final batch does not skew the epoch average.
        loss_sum += loss.item() * len(x)
        n_seen += len(x)

    train_loss = loss_sum / n_seen
    print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.2f}")

Training:  20%|██        | 1/5 [05:57<23:49, 357.47s/it]

Epoch 1/5 loss: 2.11


Training:  40%|████      | 2/5 [11:55<17:53, 357.84s/it]

Epoch 2/5 loss: 1.86


Training:  60%|██████    | 3/5 [18:16<12:16, 368.36s/it]

Epoch 3/5 loss: 1.77


Training:  80%|████████  | 4/5 [24:28<06:09, 369.65s/it]

Epoch 4/5 loss: 1.71


Training: 100%|██████████| 5/5 [30:47<00:00, 369.44s/it]

Epoch 5/5 loss: 1.68





In [6]:
# Test loop
# Switch off training-only behaviour (dropout, etc.) before evaluating.
model.eval()
with torch.no_grad():
    correct, total = 0, 0
    loss_sum = 0.0
    for x, y in tqdm(test_loader, desc="Testing"):
        x, y = x.to(device), y.to(device)
        y_hat = model(x)
        loss = criterion(y_hat, y)
        # criterion returns the batch mean. Weight it by batch size because
        # the last test batch is smaller than BATCH_SIZE.
        loss_sum += loss.item() * len(x)

        correct += (y_hat.argmax(dim=1) == y).sum().item()
        total += len(x)

    test_loss = loss_sum / total
    print(f"Test loss: {test_loss:.2f}")
    print(f"Test accuracy: {correct / total * 100:.2f}%")

Testing: 100%|██████████| 79/79 [00:55<00:00,  1.41it/s]

Test loss: 1.66
Test accuracy: 79.95%



