# Vision Transformer (ViT) for MNIST

In this notebook, we test a ViT implementation with MNIST dataset.

In [1]:
import torch
import torchvision
# vit-pytorch
from vit_pytorch import ViT
#
import numpy as np
#
import tqdm
import time
# Matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

## Loading the MNIST dataset

In [2]:
torch.manual_seed(42)

DOWNLOAD_PATH = "./data/mnist"
BATCH_SIZE_TRAIN = 100
BATCH_SIZE_TEST = 1000

transform = torchvision.transforms.ToTensor()
# transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
#                                             torchvision.transforms.Normalize([0.1307,], [0.3081,])])

# Dataset
train_set = torchvision.datasets.MNIST(DOWNLOAD_PATH, train=True, download=True,
                                       transform=transform)
test_set  = torchvision.datasets.MNIST(DOWNLOAD_PATH, train=False, download=True,
                                       transform=transform)
# Dataloader
train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE_TRAIN, shuffle=True)
test_loader  = torch.utils.data.DataLoader(test_set,  batch_size=BATCH_SIZE_TEST,  shuffle=True)

## Training and Evaluation Functions

In [3]:
# Traning functino for one epoch
def train_epoch(model, optimizer, loss_fn, data_loader, loss_history):
    total_samples = len(data_loader.dataset)
    
    model.train() # Enter the training mode
    
    total_loss = 0
    
    for _i, (data, target) in enumerate(tqdm.tqdm(data_loader)):
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        
        # Accumulate
        total_loss += loss.item()
        
        if _i % 100 == 0:
#             print('[' +  '{:5}'.format(_i * len(data)) + '/' + '{:5}'.format(total_samples) +
#                   ' (' + '{:3.0f}'.format(100 * _i / len(data_loader)) + '%)]  Loss: ' +
#                   '{:6.4f}'.format(loss.item()))
            loss_history.append(loss.item())
    
    avg_loss = total_loss / total_samples
    print('\nAverage train loss: ' + '{:.4f}'.format(avg_loss))

In [4]:
# Evaluation function
def evaluate(model, loss_fn, data_loader, loss_history):
    model.eval() # Enter the evaluation mode
    
    total_samples = len(data_loader.dataset)
    correct_samples = 0
    total_loss = 0
    
    with torch.no_grad():
        for data, target in tqdm.tqdm(data_loader):
            output = model(data)
            # Analyze the result
            loss = loss_fn(output, target) # The loss
            _, pred = torch.max(output, dim=1) # The predicted class
            # Accumulate
            total_loss += loss.item()
            correct_samples += pred.eq(target).sum()
            
    avg_loss = total_loss / total_samples
    loss_history.append(avg_loss)
    print('\nAverage test loss: ' + '{:.4f}'.format(avg_loss) +
          '  Accuracy:' + '{:5}'.format(correct_samples) + '/' +
          '{:5}'.format(total_samples) + ' (' +
          '{:4.2f}'.format(100.0 * correct_samples / total_samples) + '%)\n')

## Generate the Model (ViT) and Train

A ViT model

In [5]:
# # Instantiate a model
# model = ViT(image_size=28, patch_size=7, num_classes=10, channels=1,
#             dim=64, depth=6, heads=8, mlp_dim=128)
# optimizer = torch.optim.Adam(model.parameters(), lr=0.003)
# loss_fn = torch.nn.CrossEntropyLoss()

In [6]:
# Instantiate a model
model_vit = ViT(image_size=28, patch_size=4, num_classes=10, channels=1,
            dim=22, depth=3, heads=1, mlp_dim=22)
optimizer_vit = torch.optim.Adam(model_vit.parameters(), lr=0.003)
loss_fn_vit = torch.nn.CrossEntropyLoss()
#
# print(model_vit)

A simple CNN model

In [7]:
from torch import nn
import torch.nn.functional as F
class CNNnet(nn.Module):
    def __init__(self):
        super(CNNnet, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # return F.log_softmax(x)
        return x # directly output the score

In [8]:
# Instantiate a model
model_cnn = CNNnet()
optimizer_cnn = torch.optim.Adam(model_cnn.parameters(), lr=0.003)
loss_fn_cnn = torch.nn.CrossEntropyLoss()

## Model Size

The tool for measuring the model size

In [9]:
def get_model_size(model):
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()

    # The total size
    total_size = param_size + buffer_size
    # The scaling for display
    if total_size >= 1024**2:
        size_level = 'MB'
        size_scale = 1.0/1024**2
    else:
        size_level = 'kB'
        size_scale = 1.0/1024
    print('='*30)
    print(f'model paramater size: {(param_size*size_scale) :.3f}{size_level}')
    print(f'model buffer size: {(buffer_size*size_scale) :.3f}{size_level}')
    print('-'*30)
    print(f'model buffer size: {(total_size*size_scale) :.3f}{size_level}')
    print('='*30)

Check the model sizes

In [10]:
get_model_size(model_vit)

model paramater size: 86.062kB
model buffer size: 0.000kB
------------------------------
model buffer size: 86.062kB


In [11]:
get_model_size(model_cnn)

model paramater size: 85.312kB
model buffer size: 0.000kB
------------------------------
model buffer size: 85.312kB


## Train the Mode

In [12]:
def train_loop(model, optimizer, loss_fn, train_loader, test_loader, n_epochs=10):
    start_time = time.time()
    train_loss_history, test_loss_history = list(), list()
    for epoch in range(1, n_epochs+1):
        print('EPOCH:', epoch)
        train_epoch(model, optimizer, loss_fn, train_loader, train_loss_history)
        evaluate(model, loss_fn, test_loader, test_loss_history)

    print('Execution time:', '{:5.2f}'.format(time.time() - start_time), 'seconds')
    return model

The ViT model

In [13]:
train_loop(model_vit, optimizer_vit, loss_fn_vit, train_loader, test_loader, n_epochs=10)

EPOCH: 1


100%|████████████████████████████████████████████████████████████████████████████████| 600/600 [01:42<00:00,  5.86it/s]



Average train loss: 0.0073


100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:06<00:00,  1.49it/s]



Average test loss: 0.0003  Accuracy: 8942/10000 (89.42%)

EPOCH: 2


100%|████████████████████████████████████████████████████████████████████████████████| 600/600 [01:37<00:00,  6.13it/s]



Average train loss: 0.0028


100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:07<00:00,  1.36it/s]



Average test loss: 0.0002  Accuracy: 9325/10000 (93.25%)

EPOCH: 3


 68%|██████████████████████████████████████████████████████▎                         | 407/600 [00:58<00:27,  6.93it/s]


KeyboardInterrupt: 

The CNN model

In [None]:
train_loop(model_cnn, optimizer_cnn, loss_fn_cnn, train_loader, test_loader, n_epochs=10)

## Test

In [None]:
def get_prediction(model, data_batch):
    # Prediction by model, in batch
    model.eval()
    pred_scores = model(data_batch)
    _, preds = torch.max(pred_scores, dim=1) # The predicted class
    return preds

In [None]:
def plot_image_tiles(images, preds, labels):
    # plot the first 20 images in the batch, along with the corresponding labels and predictions
    fig = plt.figure(figsize=(25, 4))
    for idx in np.arange(20):
        ax = fig.add_subplot(2, 20//2, idx+1, xticks=[], yticks=[])
        plt.imshow(np.transpose(images[idx], (1, 2, 0))) # (c,h,w) --> (h,w,c)
        ax.set_title(f'p:{preds[idx]}' + (f' gt:{labels[idx]} (x)' if preds[idx]!=labels[idx] else ''))

In [None]:
# Visualize some training data
# obtain one batch of training images
dataiter = iter( train_loader )
images_tensor, labels = dataiter.next()
images = images_tensor.numpy() # convert images to numpy for display

print(f'\nimages[0,0,0,0] = {images[0,0,0,0]}\n')
print(f'\nimages[0].mean() = {images[0].mean()}\n')

# Prediction by model, in batch
preds_vit = get_prediction(model_vit, images_tensor)
preds_cnn = get_prediction(model_cnn, images_tensor)

print(f'{preds_vit = }')
print(f'{preds_cnn = }')


# plot the first 20 images in the batch, along with the corresponding labels and predictions
# fig = plt.figure(figsize=(25, 4))
# for idx in np.arange(20):
#     ax = fig.add_subplot(2, 20//2, idx+1, xticks=[], yticks=[])
#     plt.imshow(np.transpose(images[idx], (1, 2, 0))) # (c,h,w) --> (h,w,c)
#     ax.set_title(f'p:{preds[idx]}' + (f' gt:{labels[idx]} (x)' if preds[idx]!=labels[idx] else ''))
plot_image_tiles(images, preds_vit, labels)
plot_image_tiles(images, preds_cnn, labels)

## Calculate the Model Size