# importing libraries

In [None]:
import numpy as np

from tqdm import tqdm, trange

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader

from torchvision.transforms import ToTensor
from torchvision.datasets.mnist import MNIST

# Design of multi-head self-attention

In [73]:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d, n_heads=2):
        super(MultiHeadSelfAttention, self).__init__()
        self.d=d
        self.n_heads=n_heads
        assert d%n_heads==0, f"Can't divide dimension {d} into {n_heads} heads"
        d_head=int(d/n_heads)
        self.q_mappings=nn.ModuleList([nn.Linear(d_head,d_head) for _ in range(self.n_heads)])
        self.k_mappings=nn.ModuleList([nn.Linear(d_head,d_head) for _ in range(self.n_heads)])
        self.v_mappings=nn.ModuleList([nn.Linear(d_head,d_head) for _ in range(self.n_heads)])
        self.d_head=d_head
        self.softmax=nn.Softmax(dim=-1)
        
    def forward(self, sequences):
        result=[]
        for sequence in sequences:
            seq_result=[]
            for head in range(self.n_heads):
                q_mapping=self.q_mappings[head]
                k_mapping=self.k_mappings[head]
                v_mapping=self.v_mappings[head]
                    
                seq=sequence[:,head*self.d_head:(head+1)*self.d_head]
                q,k,v=q_mapping(seq), k_mapping(seq), v_mapping(seq)
                    
                attention=self.softmax(q@k.T/(self.d_head**0.5))
                seq_result.append(attention@v)
            result.append(torch.hstack(seq_result))
        return torch.cat([torch.unsqueeze(r, dim=0) for r in result])

# Encoder block

In [74]:
class SimpleVitBlock(nn.Module):
    def __init__(self, hidden_d, n_heads, mlp_ratio=4):
        super(SimpleVitBlock, self).__init__()
        self.hidden_d=hidden_d
        self.n_heads=n_heads
        
        self.norm1=nn.LayerNorm(hidden_d)
        self.mhsa=MultiHeadSelfAttention(hidden_d, n_heads)
        self.norm2=nn.LayerNorm(hidden_d)
        self.mlp=nn.Sequential(
            nn.Linear(hidden_d, mlp_ratio*hidden_d),
            nn.GELU(),
            nn.Linear(mlp_ratio*hidden_d, hidden_d)
        )
    def forward(self, x):
        out=x+self.mhsa(self.norm1(x))
        out=out+self.mlp(self.norm2(out))
        return out

# My Vision Transformer

In [75]:
class SimpleVit(nn.Module):
    def __init__(self, chw=(1, 28, 28), n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10):
        
        super(SimpleVit, self).__init__()
        self.chw=chw
        self.n_patches=n_patches
        self.hidden_d=hidden_d
        self.n_blocks=n_blocks
        self.n_heads=n_heads
        
        assert self.chw[1]%n_patches==0, "input shape not entirely divisible by number of patches"
        assert self.chw[2]%n_patches==0, "input shape not entirely divisible by number of patches"
        self.patch_size=(chw[1]/n_patches, chw[2]/n_patches)
        
        self.input_d=int(chw[0] * self.patch_size[0]*self.patch_size[1])
        self.linear_mapper=nn.Linear(self.input_d, self.hidden_d)
        
        self.class_token=nn.Parameter(torch.randn(1, self.hidden_d))
        
        self.pos_embed=nn.Parameter(get_positional_embeddings(self.n_patches**2+1, self.hidden_d))
        self.pos_embed.requires_grad=False
        
        self.blocks=nn.ModuleList([SimpleVitBlock(hidden_d, n_heads) for _ in range(n_blocks)])
        
        self.mlp=nn.Sequential(
            nn.Linear(self.hidden_d, out_d),
            nn.Softmax(dim=1)
        )
    def forward(self, images):
        n, c, h, w = images.shape
        patches=patchify(images, self.n_patches)
        tokens=self.linear_mapper(patches)
        tokens=torch.stack([torch.vstack((self.class_token, tokens[i])) for i in range(len(tokens))])
        pos_embed=self.pos_embed.repeat(n,1,1)
        out=tokens+pos_embed
        for block in self.blocks:
            out=block(out)
        out=out[:,0]
        out=self.mlp(out)
        return out

# Data Loading, training, and testing

In [76]:
def main():
    # Loading data
    transform = ToTensor()

    train_set = MNIST(root='./../datasets', train=True, download=True, transform=transform)
    test_set = MNIST(root='./../datasets', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_set, shuffle=True, batch_size=128)
    test_loader = DataLoader(test_set, shuffle=False, batch_size=128)

    # Defining model and training options
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")
    model = SimpleVit((1, 28, 28), n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10).to(device)
    N_EPOCHS = 5
    LR = 0.005

    # Training loop
    optimizer = Adam(model.parameters(), lr=LR)
    criterion = CrossEntropyLoss()
    for epoch in trange(N_EPOCHS, desc="Training"):
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)

            train_loss += loss.detach().cpu().item() / len(train_loader)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.2f}")
    # Test loop
    with torch.no_grad():
        correct, total = 0, 0
        test_loss = 0.0
        for batch in tqdm(test_loader, desc="Testing"):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)
            test_loss += loss.detach().cpu().item() / len(test_loader)

            correct += torch.sum(torch.argmax(y_hat, dim=1) == y).detach().cpu().item()
            total += len(x)
        print(f"Test loss: {test_loss:.2f}")
        print(f"Test accuracy: {correct / total * 100:.2f}%")

# Calling main function

The simple vision transformer is trained for 5 epochs over the MNIST dataset to check the performance of the transformer.

In [77]:
if __name__=='__main__':
    main()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./../datasets\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 21301203.99it/s]


Extracting ./../datasets\MNIST\raw\train-images-idx3-ubyte.gz to ./../datasets\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./../datasets\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 172145.26it/s]


Extracting ./../datasets\MNIST\raw\train-labels-idx1-ubyte.gz to ./../datasets\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./../datasets\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 40448776.73it/s]


Extracting ./../datasets\MNIST\raw\t10k-images-idx3-ubyte.gz to ./../datasets\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./../datasets\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<?, ?it/s]


Extracting ./../datasets\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./../datasets\MNIST\raw

Using device:  cpu 


Training:   0%|          | 0/5 [00:00<?, ?it/s]
Epoch 1 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 1 in training:   0%|          | 1/469 [00:01<12:19,  1.58s/it][A
Epoch 1 in training:   0%|          | 2/469 [00:02<07:11,  1.08it/s][A
Epoch 1 in training:   1%|          | 3/469 [00:02<05:20,  1.45it/s][A
Epoch 1 in training:   1%|          | 4/469 [00:02<04:34,  1.70it/s][A
Epoch 1 in training:   1%|          | 5/469 [00:03<04:07,  1.88it/s][A
Epoch 1 in training:   1%|▏         | 6/469 [00:03<03:56,  1.95it/s][A
Epoch 1 in training:   1%|▏         | 7/469 [00:04<03:51,  2.00it/s][A
Epoch 1 in training:   2%|▏         | 8/469 [00:04<04:01,  1.91it/s][A
Epoch 1 in training:   2%|▏         | 9/469 [00:05<04:04,  1.88it/s][A
Epoch 1 in training:   2%|▏         | 10/469 [00:05<03:45,  2.04it/s][A
Epoch 1 in training:   2%|▏         | 11/469 [00:06<03:52,  1.97it/s][A
Epoch 1 in training:   3%|▎         | 12/469 [00:06<03:37,  2.10it/s][A
Epoch 1 in training: 

Epoch 1/5 loss: 2.17



Epoch 2 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 2 in training:   0%|          | 1/469 [00:00<03:27,  2.26it/s][A
Epoch 2 in training:   0%|          | 2/469 [00:00<03:31,  2.21it/s][A
Epoch 2 in training:   1%|          | 3/469 [00:01<03:32,  2.20it/s][A
Epoch 2 in training:   1%|          | 4/469 [00:01<03:41,  2.10it/s][A
Epoch 2 in training:   1%|          | 5/469 [00:02<03:33,  2.18it/s][A
Epoch 2 in training:   1%|▏         | 6/469 [00:02<03:32,  2.18it/s][A
Epoch 2 in training:   1%|▏         | 7/469 [00:03<03:41,  2.09it/s][A
Epoch 2 in training:   2%|▏         | 8/469 [00:03<03:36,  2.13it/s][A
Epoch 2 in training:   2%|▏         | 9/469 [00:04<03:36,  2.13it/s][A
Epoch 2 in training:   2%|▏         | 10/469 [00:04<03:40,  2.08it/s][A
Epoch 2 in training:   2%|▏         | 11/469 [00:05<03:40,  2.07it/s][A
Epoch 2 in training:   3%|▎         | 12/469 [00:05<03:40,  2.07it/s][A
Epoch 2 in training:   3%|▎         | 13/469 [00:06<03:34,  2.13it/s

Epoch 2/5 loss: 2.11



Epoch 3 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 3 in training:   0%|          | 1/469 [00:00<04:14,  1.84it/s][A
Epoch 3 in training:   0%|          | 2/469 [00:01<03:53,  2.00it/s][A
Epoch 3 in training:   1%|          | 3/469 [00:01<03:59,  1.95it/s][A
Epoch 3 in training:   1%|          | 4/469 [00:02<03:52,  2.00it/s][A
Epoch 3 in training:   1%|          | 5/469 [00:02<03:45,  2.05it/s][A
Epoch 3 in training:   1%|▏         | 6/469 [00:02<03:47,  2.03it/s][A
Epoch 3 in training:   1%|▏         | 7/469 [00:03<03:52,  1.99it/s][A
Epoch 3 in training:   2%|▏         | 8/469 [00:03<03:45,  2.05it/s][A
Epoch 3 in training:   2%|▏         | 9/469 [00:04<03:56,  1.94it/s][A
Epoch 3 in training:   2%|▏         | 10/469 [00:05<03:49,  2.00it/s][A
Epoch 3 in training:   2%|▏         | 11/469 [00:05<03:49,  2.00it/s][A
Epoch 3 in training:   3%|▎         | 12/469 [00:05<03:39,  2.08it/s][A
Epoch 3 in training:   3%|▎         | 13/469 [00:06<03:33,  2.14it/s

Epoch 3/5 loss: 2.07



Epoch 4 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 4 in training:   0%|          | 1/469 [00:00<04:07,  1.89it/s][A
Epoch 4 in training:   0%|          | 2/469 [00:01<04:02,  1.93it/s][A
Epoch 4 in training:   1%|          | 3/469 [00:01<03:58,  1.95it/s][A
Epoch 4 in training:   1%|          | 4/469 [00:02<03:56,  1.96it/s][A
Epoch 4 in training:   1%|          | 5/469 [00:02<03:58,  1.94it/s][A
Epoch 4 in training:   1%|▏         | 6/469 [00:03<03:56,  1.96it/s][A
Epoch 4 in training:   1%|▏         | 7/469 [00:03<03:56,  1.95it/s][A
Epoch 4 in training:   2%|▏         | 8/469 [00:04<03:58,  1.93it/s][A
Epoch 4 in training:   2%|▏         | 9/469 [00:04<04:02,  1.90it/s][A
Epoch 4 in training:   2%|▏         | 10/469 [00:05<04:13,  1.81it/s][A
Epoch 4 in training:   2%|▏         | 11/469 [00:05<04:15,  1.79it/s][A
Epoch 4 in training:   3%|▎         | 12/469 [00:06<04:08,  1.84it/s][A
Epoch 4 in training:   3%|▎         | 13/469 [00:06<04:11,  1.82it/s

Epoch 4/5 loss: 2.00



Epoch 5 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 5 in training:   0%|          | 1/469 [00:00<04:12,  1.85it/s][A
Epoch 5 in training:   0%|          | 2/469 [00:01<04:04,  1.91it/s][A
Epoch 5 in training:   1%|          | 3/469 [00:01<04:00,  1.93it/s][A
Epoch 5 in training:   1%|          | 4/469 [00:02<03:52,  2.00it/s][A
Epoch 5 in training:   1%|          | 5/469 [00:02<03:52,  1.99it/s][A
Epoch 5 in training:   1%|▏         | 6/469 [00:03<03:54,  1.98it/s][A
Epoch 5 in training:   1%|▏         | 7/469 [00:03<03:54,  1.97it/s][A
Epoch 5 in training:   2%|▏         | 8/469 [00:04<03:55,  1.96it/s][A
Epoch 5 in training:   2%|▏         | 9/469 [00:04<03:50,  1.99it/s][A
Epoch 5 in training:   2%|▏         | 10/469 [00:05<03:56,  1.94it/s][A
Epoch 5 in training:   2%|▏         | 11/469 [00:05<03:49,  1.99it/s][A
Epoch 5 in training:   3%|▎         | 12/469 [00:06<03:51,  1.98it/s][A
Epoch 5 in training:   3%|▎         | 13/469 [00:06<03:46,  2.01it/s

Epoch 5/5 loss: 1.96


Testing: 100%|██████████| 79/79 [00:22<00:00,  3.44it/s]

Test loss: 1.97
Test accuracy: 49.10%



