In [7]:

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import numpy as np

In [8]:
class PatchEmbed(nn.Module):
    """Cut a square image into non-overlapping patches and embed each patch.

    img_size : Side length of the (square) input image.
    patch_size : Side length of each (square) patch.
    in_channels : Number of channels of the input image.
    embed_dim : Length of the vector every patch is projected to.

    The number of patches per image is stored in ``n_patches`` and equals
    ``(img_size // patch_size) ** 2``.
    """

    def __init__(self, img_size:int=224, patch_size:int=16, in_channels:int=1, embed_dim:int=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size

        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side ** 2

        # Kernel size == stride == patch size, so the convolution visits each
        # patch exactly once and emits one embed_dim-long vector per patch.
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Map images (n_samples, in_channels, img_size, img_size) to
        patch embeddings (n_samples, n_patches, embed_dim)."""
        patch_grid = self.proj(x)  # (n_samples, embed_dim, n_patches**0.5, n_patches**0.5)
        # Collapse the spatial grid, then put the patch axis before the embedding axis.
        return patch_grid.flatten(2).transpose(1, 2)

In [9]:
class Attention(nn.Module):
    '''Multi-head self-attention mechanism.

    dim : Size of each input (and output) token embedding.
    n_heads : Number of attention heads; must be positive and divide dim evenly.
    qkv_bias : If True then a bias is included in the query, key and value projections.
    attn_p : Dropout probability applied to the attention weights (after the softmax).
    proj_p : Dropout probability applied to the output tensor.
    Note : Dropout is only applied during training, not during evaluation or prediction.
    '''

    def __init__(self, dim, n_heads, qkv_bias:bool=False, attn_p:float=0., proj_p:float=0):
        '''Build the projections. Raises ValueError if dim is not divisible by n_heads.'''
        super().__init__()
        # Fail early: otherwise head_dim * n_heads != dim and the reshape in
        # forward() blows up with a much less helpful error.
        if n_heads <= 0 or dim % n_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.dim = dim
        self.head_dim = dim // n_heads
        # Scaling keeps the dot products from growing with head_dim before the softmax.
        self.scale = self.head_dim ** -0.5
        # Input : embedding | Output : query, key and value vectors of the embedding
        # Note : three separate linear mappings would do the same thing
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_p)
        # Takes the concatenated heads and maps them into a new space
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_p)

    def forward(self, x):
        '''Apply self-attention to x of shape (n_samples, n_tokens, dim).

        Returns a tensor of the same shape.
        Raises ValueError if the last dimension of x differs from dim.
        '''
        n_samples, n_tokens, dim = x.shape
        if dim != self.dim:
            raise ValueError(
                f"Expected last dimension of input to be {self.dim}, got {dim}"
            )

        qkv = self.qkv(x) # (n_samples, n_tokens, 3 * dim)
        qkv = qkv.reshape(
            n_samples, n_tokens, 3, self.n_heads, self.head_dim
        ) # (n_samples, n_tokens, 3, n_heads, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4) # (3, n_samples, n_heads, n_tokens, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]
        k_t = k.transpose(-2, -1) # (n_samples, n_heads, head_dim, n_tokens)
        dot_product = (
            q @ k_t
        ) * self.scale # (n_samples, n_heads, n_tokens, n_tokens)
        attn = dot_product.softmax(dim=-1) # (n_samples, n_heads, n_tokens, n_tokens)
        attn = self.attn_drop(attn)
        weighted_avg = attn @ v # (n_samples, n_heads, n_tokens, head_dim)
        # Bring the head axis next to head_dim so the two can be merged
        weighted_avg = weighted_avg.transpose(1, 2) # (n_samples, n_tokens, n_heads, head_dim)
        # Flatten last 2 dimensions <=> concatenate the output of each head.
        # n_heads * head_dim == dim, so we recover the input dimension.
        weighted_avg = weighted_avg.flatten(2) # (n_samples, n_tokens, dim)
        # Final linear projection and dropout
        x = self.proj(weighted_avg) # (n_samples, n_tokens, dim)
        x = self.proj_drop(x) # (n_samples, n_tokens, dim)

        return x
    
class MLP(nn.Module):
    '''Multilayer perceptron with a single hidden layer.

    in_features : Size of each input vector.
    hidden_features : Number of units in the hidden layer.
    out_features : Size of each output vector.
    p : Dropout probability, applied after the activation and after the output layer.
    '''
    def __init__(self, in_features, hidden_features, out_features, p=0.):
        '''Create the two linear layers, the GELU activation and the dropout.'''
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.activation = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.dropout = nn.Dropout(p)

    def forward(self, x):
        '''Run x through fc1 -> GELU -> dropout -> fc2 -> dropout.'''
        hidden = self.fc1(x)
        hidden = self.activation(hidden)
        hidden = self.dropout(hidden)
        out = self.fc2(hidden)
        out = self.dropout(out)
        return out


class Block(nn.Module):
    '''Transformer encoder block: pre-norm self-attention followed by a pre-norm MLP,
    each wrapped in a residual connection.

    dim : Embedding dimension of the tokens.
    n_heads : Number of attention heads.
    mlp_ratio : Determines the hidden dimension size of the MLP module with respect to dim.
    qkv_bias : If True, a bias is included in the query, key and value projections.
    p : Dropout probability for the attention output projection and for the MLP.
    attn_p : Dropout probability applied inside the attention module.
    '''

    def __init__(self, dim, n_heads, mlp_ratio, qkv_bias, p, attn_p):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = Attention(dim, n_heads, qkv_bias, attn_p, proj_p=p)
        self.mlp = MLP(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            out_features=dim,
            # Forward the block-level dropout; without this the MLP silently
            # used its default (0.) whatever value of p the caller requested.
            p=p,
        )

    def forward(self, x):
        '''Apply the attention and MLP sub-layers; the output has the same shape as x.'''
        # Residual connections: each sub-layer only adds a correction to its input.
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class VisionTransformer(nn.Module):
    '''Vision Transformer classifier.

    img_size : Side length of the (square) input images.
    patch_size : Side length of the (square) patches.
    in_chans : Number of input channels.
    n_classes : Number of output classes.
    embed_dim : Dimension of the token embeddings.
    depth : Number of stacked Block modules.
    n_heads : Number of attention heads in every block.
    mlp_ratio : Hidden size of each block's MLP relative to embed_dim.
    qkv_bias : If True, a bias is included in the query, key and value projections.
    p : Dropout probability (positional dropout here, and passed on to every block).
    attn_p : Attention dropout probability passed on to every block.
    '''
    def __init__(self, img_size=384, patch_size=16, in_chans=3, n_classes=1000, 
                 embed_dim=768, depth=12, n_heads=4, mlp_ratio=4., qkv_bias=True, p=0., attn_p=0.):
        super().__init__()
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_chans,
            embed_dim=embed_dim,
        )

        # One extra position for the class token that is prepended to the patches.
        seq_len = 1 + self.patch_embed.n_patches
        # Learnable first token of the sequence (embed_dim elements), zero-initialised.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Learnable positional embedding, one vector per token, zero-initialised.
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        self.pos_drop = nn.Dropout(p=p)

        encoder_layers = []
        for _ in range(depth):
            encoder_layers.append(
                Block(
                    dim=embed_dim,
                    n_heads=n_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    p=p,
                    attn_p=attn_p,
                )
            )
        self.blocks = nn.ModuleList(encoder_layers)

        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
        self.head = nn.Linear(embed_dim, n_classes)

    def forward(self, x):
        '''Compute class logits of shape (n_samples, n_classes) for a batch of images.'''
        batch_size = x.shape[0]
        patches = self.patch_embed(x)  # (n_samples, n_patches, embed_dim)

        # Broadcast the single class token over the batch and prepend it.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # (n_samples, 1, embed_dim)
        tokens = torch.cat((cls_tokens, patches), dim=1)  # (n_samples, 1 + n_patches, embed_dim)
        tokens = self.pos_drop(tokens + self.pos_embed)

        for layer in self.blocks:
            tokens = layer(tokens)

        tokens = self.norm(tokens)
        # Only the class token (position 0) is fed to the classification head.
        return self.head(tokens[:, 0])

## Load the MNIST dataset to test the model on a classification task

In [10]:
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
import os
import random
import numpy as np
# import matplotlib
# matplotlib.use('TkAgg') # Necessary to run matplotlib
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from tqdm import tqdm

import albumentations as A
import torch
from torchvision import transforms
from torch.nn import functional as F
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

print(torch.__version__)
import copy

from collections import OrderedDict

from PIL import Image

from torchvision import datasets
from torchvision import transforms
import numpy as np
import random


# MNIST ships with separate train and test splits; each image is converted to a tensor.
train_dataset = MNIST(root='', train=True, download=True, transform=transforms.ToTensor())
valid_dataset = MNIST(root='', train=False, download=True, transform=transforms.ToTensor())

BATCH_SIZE = 16

# Shuffle only the training data: SGD benefits from a random sample order.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation gains nothing from shuffling; a fixed order keeps predictions reproducible.
test_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)

2.0.0


In [20]:
# Image width equals image height
IMG_SIZE = 28
NUM_CLASSES = 10
EPOCHS = 10
LEARNING_RATE = 0.01
DEVICE = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

# 28x28 single-channel images cut into 7x7 patches -> 16 patches per image.
model = VisionTransformer(
    img_size=IMG_SIZE,
    patch_size=7,
    in_chans=1,
    n_classes=NUM_CLASSES,
    embed_dim=8,
    depth=10,
    n_heads=4
).to(device=DEVICE)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

for epoch in range(EPOCHS):
    # Make sure training-mode layers (dropout) are active, even if this cell is
    # re-run after the evaluation cell has switched the model to eval mode.
    model.train()
    for img, y in tqdm(train_dataloader, desc=f"epoch {epoch + 1}/{EPOCHS}"):
        img = img.to(device=DEVICE)
        y = y.to(device=DEVICE)

        # prediction: class logits, dim : (batch_size, NUM_CLASSES)
        y_pred = model(img)
        # Calculate loss
        loss = criterion(y_pred, y)

        # backward: clear stale gradients before computing the new ones
        optimizer.zero_grad()
        loss.backward()

        # adam step
        optimizer.step()


100%|██████████| 3750/3750 [01:12<00:00, 52.08it/s]
100%|██████████| 3750/3750 [01:11<00:00, 52.28it/s]
100%|██████████| 3750/3750 [01:11<00:00, 52.29it/s]


In [21]:
# Evaluate classification accuracy on the test set.
model.eval()  # disable dropout for inference
with torch.no_grad():  # no gradients needed, saves memory and time
    predictions = []
    groundtruths = []
    for X, y in tqdm(test_dataloader):
        X = X.to(device=DEVICE)
        pred = model(X)
        # extend() appends in place (the previous `a = a + b` copied the whole list on
        # every batch); tolist() yields plain Python ints instead of 0-d tensors.
        predictions.extend(pred.argmax(dim=1).cpu().tolist())
        groundtruths.extend(y.tolist())
    grountruths = groundtruths  # misspelled legacy name, kept in case later cells use it
    # Accuracy = fraction of samples whose predicted class matches the label.
    print((np.array(predictions) == np.array(groundtruths)).sum() / len(predictions))
    


100%|██████████| 625/625 [00:04<00:00, 128.91it/s]

0.9159



