In [None]:
import matplotlib.pyplot as plt
import numpy as np

from ipynb.fs.defs._2_Building_a_ViT import MyVit
from tqdm import tqdm, trange

import torch
import torch.nn as nn
from torch.optim import Adam, AdamW
from torch.nn import CrossEntropyLoss
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets

from torchvision.transforms import ToTensor
from torchvision.datasets.mnist import MNIST

# Fix the seed of both random number generators used in this notebook
# (NumPy's legacy global generator and PyTorch's default generator) so that
# repeated runs start from the same state.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Load the MNIST training and test splits from ./data (download=True asks
# torchvision to fetch them there when needed); ToTensor() turns every image
# into a tensor.
training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

# INFO: PyTorch uses a different tensor layout than TensorFlow
# TensorFlow: [batch_size, height, width, channels]
# PyTorch:    [batch_size, channels, height, width]
#
# The training loader shuffles so that every epoch visits the samples in a new
# order (the order is driven by PyTorch's global RNG, so it is reproducible when
# that RNG is seeded). The test loader keeps its fixed order: evaluation does
# not depend on it.
train_loader = DataLoader(training_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=128)

# Pull a single batch to read off the input shape and preview one image.
x_sample, y_sample = next(iter(train_loader))
image_shape = x_sample.shape
print (f"Image shape is: {image_shape} (batch size, channels, height, width)")
# First image of the batch, first (only) channel.
plt.matshow(x_sample[0][0], cmap="coolwarm")
plt.show()

In [None]:
# Everything in this notebook runs on the CPU; the model and every batch are
# moved to `device` explicitly.
device = torch.device("cpu")
print("Using device: ", device)

# Optimisation hyperparameters.
n_epochs = 10
learning_rate = 0.005

# Build the vision transformer from the shape of one sample batch and move it
# to the target device in the same statement.
# NOTE(review): the meaning of the keyword arguments (patch size, embedding
# width, attention heads, encoder blocks) is inferred from their names;
# confirm against the MyVit definition.
model = MyVit(
    x_sample.shape,
    classes=10,
    p_size=7,
    embedded_dimension=8,
    n_heads=4,
    n_blocks=2,
).to(device)

# AdamW over all model parameters, cross-entropy as the classification loss.
optimizer = AdamW(model.parameters(), lr=learning_rate)
criterion = CrossEntropyLoss()

In [None]:
# Train for `n_epochs` epochs; after every epoch evaluate on the test set and
# print loss and accuracy. Losses are accumulated per sample (batch mean times
# batch size) so that a smaller final batch does not skew the epoch average.
for epoch in trange(n_epochs, desc="Training"):
    # ---- Training pass ----
    # Put the model in training mode (matters for layers such as dropout that
    # behave differently during evaluation).
    model.train()
    train_loss_sum = 0.0
    train_samples = 0
    pbar = tqdm(train_loader,
                desc=f"Epoch {epoch + 1}",
                leave=False,
                position=0)
    for x, y in pbar:
        x, y = x.to(device), y.to(device)
        y_hat = model(x)
        loss = criterion(y_hat, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean; scale back to a per-sample sum.
        train_loss_sum += loss.item() * len(x)
        train_samples += len(x)
        # Show the running loss as a postfix so the "Epoch N" label stays visible.
        pbar.set_postfix(loss=f"{train_loss_sum / train_samples:.2f}")

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    train_loss = train_loss_sum / max(train_samples, 1)
    print(f"Epoch {epoch + 1}/{n_epochs} loss: {train_loss:.2f}")

    # ---- Evaluation pass ----
    # Switch to inference behaviour and disable gradient tracking.
    model.eval()
    correct, total = 0, 0
    test_loss_sum = 0.0
    with torch.no_grad():
        for x, y in tqdm(test_loader, desc="Testing", leave=False, position=0):
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            test_loss_sum += criterion(y_hat, y).item() * len(x)

            # Predicted class = index of the largest logit per sample.
            correct += (torch.argmax(y_hat, dim=1) == y).sum().item()
            total += len(x)

    test_loss = test_loss_sum / max(total, 1)
    print(f"Test loss: {test_loss:.2f}")
    print(f"Test accuracy: {correct / max(total, 1) * 100:.2f}%")


# NOTE: this pickles the whole module object, so loading the file later requires
# the MyVit class to be importable under the same path. Saving
# model.state_dict() would be more portable, but the format is kept unchanged
# here so existing code that loads "MNIST_transformer.pt" keeps working.
torch.save(model, "MNIST_transformer.pt")