# Vision Transformer



### Installing PyTorch Lightning
A deep learning framework that organizes PyTorch code and makes training easier.

In [2]:
!pip install pytorch_lightning

Collecting pytorch_lightning
  Downloading pytorch_lightning-2.0.5-py3-none-any.whl (722 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/722.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━[0m [32m501.8/722.4 kB[0m [31m14.5 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m722.4/722.4 kB[0m [31m14.7 MB/s[0m eta [36m0:00:00[0m
Collecting torchmetrics>=0.7.0 (from pytorch_lightning)
  Downloading torchmetrics-1.0.1-py3-none-any.whl (729 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m729.2/729.2 kB[0m [31m47.3 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.7.0 (from pytorch_lightning)
  Downloading lightning_utilities-0.9.0-py3-none-any.whl (23 kB)
Installing collected packages: lightning-utilities, torchmetrics, pytorch_lightning
Successfully installed lightning-utilities-0.9.0 pytorch_lightning-2.0.5 torchmetric

### Necessary imports

In [3]:
import os
from PIL import Image
import torchvision
import torch
import math
import torch.nn as nn
from matplotlib import pyplot as plt
import torchvision.transforms as transforms
import pytorch_lightning as pl
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, random_split
from torch.optim.lr_scheduler import StepLR

## Split images into patches
This class treats each image as a tensor and splits it into non-overlapping patches with the `nn.Unfold` function.

In [4]:
class ImagePatching(nn.Module):
  """Split a batch of images into flattened, non-overlapping square patches.

  Args:
    image_size: side length of the (square) input images, in pixels.
    patch_size: side length of each square patch. If `image_size` is not a
      multiple of it, the border pixels that do not fill a whole patch are
      dropped by `nn.Unfold`.
  """

  def __init__(self, image_size, patch_size):
    super().__init__()
    self.image_size = image_size
    self.patch_size = patch_size
    # number of patches per image: (patches per side) squared
    self.nb_patch = (image_size // patch_size) ** 2

  def forward(self, image):
    """Return patches of shape (batch, nb_patch, channels * patch_size**2).

    `image` is expected to be a batched (batch, channels, height, width) tensor;
    `permute(0, 2, 1)` below requires the 3-D batched output of `nn.Unfold`.
    """
    unfold = nn.Unfold(self.patch_size, stride=self.patch_size)
    # nn.Unfold yields (batch, C * p * p, nb_patch); move the patch axis second.
    return unfold(image).permute(0, 2, 1)

  def print_patches(self, patches):
    """Display patches in an nb x nb grid, each at its position in the original image.

    Args:
      patches: sequence of `nb_patch` patch images, each (channels, patch_size,
        patch_size) with values in [0, 1]. NOTE: this is not the flattened output
        of `forward`; reshape that first.
    """
    nb = int(math.sqrt(self.nb_patch))
    # squeeze=False keeps `axes` 2-D even when nb == 1
    _, axes = plt.subplots(nb, nb, figsize=(8, 8), squeeze=False)
    for i, im in enumerate(patches):
      # patch i sits at row i // nb, column i % nb (0-based, row-major)
      ax = axes[i // nb][i % nb]
      ax.imshow(im.permute(1, 2, 0).detach().cpu().numpy(), vmin=0, vmax=1)
      ax.axis("off")
    plt.tight_layout()
    plt.show()



## Embedding Tokens
Create embedding tokens by:
<ul>
  <li> Splitting images into patches </li>
  <li> Flattening these patches and applying a linear projection </li>
  <li> Adding a classification token </li>
  <li> Adding a position embedding to each token </li>
</ul>


In [5]:
class Embedding_tokens(nn.Module):
  """Turn a batch of images into a sequence of token embeddings for the encoder.

  Steps: split each image into patches, project every flattened patch to
  `hidden_dim` with a linear layer, prepend a learnable classification token,
  then add fixed sinusoidal position encodings along the sequence axis.

  Args:
    image_size: side length of the square input images.
    n_channels: number of image channels.
    patch_size: side length of each square patch.
    hidden_dim: embedding size of each token.
  """

  def __init__(self, image_size, n_channels, patch_size, hidden_dim):
    super().__init__()
    self.image_size = image_size
    self.patch_size = patch_size
    self.hidden_dim = hidden_dim
    self.patching = ImagePatching(image_size, patch_size)
    # learnable classification token, shape (1, hidden_dim)
    self.class_token = nn.Parameter(torch.rand(1, hidden_dim))
    self.linear = nn.Linear(patch_size * patch_size * n_channels, hidden_dim)

  def positional_encodings(self, x):
    """Add sinusoidal position encodings along the sequence axis of `x`.

    `x` is (batch, seq_len, hidden_dim). The token at position `pos` receives
      PE[pos, 2i]   = sin(pos / 10000^(2i / hidden_dim))
      PE[pos, 2i+1] = cos(pos / 10000^(2i / hidden_dim))
    and the same encoding is added to every sample in the batch.
    (Previously the table was indexed by batch position instead of sequence
    position, so tokens within one image were indistinguishable.)

    Returns a tensor with the same shape as `x`.
    """
    seq_len = x.size(1)
    position = torch.arange(seq_len, dtype=torch.float32, device=x.device).unsqueeze(1)  # (seq_len, 1)
    # 1 / 10000^(2i / hidden_dim), computed in log space
    div_term = torch.exp(
        torch.arange(0, self.hidden_dim, 2, dtype=torch.float32, device=x.device)
        * (-math.log(10000.0) / self.hidden_dim))
    angles = position * div_term  # (seq_len, ceil(hidden_dim / 2))
    pe = torch.zeros(seq_len, self.hidden_dim, device=x.device)
    pe[:, 0::2] = torch.sin(angles)
    # with an odd hidden_dim there is one fewer odd column than even column
    pe[:, 1::2] = torch.cos(angles[:, : self.hidden_dim // 2])
    # unsqueeze(0) broadcasts the (seq_len, hidden_dim) table over the batch axis
    return x + pe.unsqueeze(0).to(x.dtype)

  def forward(self, image):
    """Return tokens of shape (batch, nb_patch + 1, hidden_dim); index 0 is the class token."""
    patches = self.linear(self.patching(image))  # (batch, nb_patch, hidden_dim)
    # one copy of the class token per sample: (batch, 1, hidden_dim)
    cls = self.class_token.unsqueeze(0).expand(patches.size(0), -1, -1)
    tokens = torch.cat((cls, patches), dim=1)
    return self.positional_encodings(tokens)



## Multi-head attention



In [6]:
class MultiHead_Attention(nn.Module):
  """Multi-head scaled dot-product self-attention.

  The feature axis of width `d_model` is split into `num_heads` contiguous
  slices of width `dk = d_model // num_heads`. Each head has its own
  query/key/value projections acting on its slice only. The head outputs are
  concatenated and passed through a final `d_model -> d_model` projection.

  Args:
    d_model: feature size of each token; must be divisible by `num_heads`.
    num_heads: number of attention heads.

  Raises:
    ValueError: if `d_model` is not divisible by `num_heads`. Otherwise the
      concatenated heads would not match the output projection's input size.
  """

  def __init__(self, d_model, num_heads):
    super().__init__()
    if d_model % num_heads != 0:
      raise ValueError(f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")
    self.d_model = d_model
    self.num_heads = num_heads
    self.dk = d_model // num_heads
    self.query_projection = nn.ModuleList([nn.Linear(self.dk, self.dk) for _ in range(self.num_heads)])
    self.key_projection = nn.ModuleList([nn.Linear(self.dk, self.dk) for _ in range(self.num_heads)])
    self.value_projection = nn.ModuleList([nn.Linear(self.dk, self.dk) for _ in range(self.num_heads)])
    self.output_projection = nn.Linear(self.d_model, self.d_model)
    self.softmax = nn.Softmax(dim=-1)

  def forward(self, sequences):
    """Self-attention over `sequences` of shape (batch, seq_len, d_model); returns the same shape.

    The whole batch is processed at once per head (no per-sample Python loop),
    so an empty batch is also handled.
    """
    head_outputs = []
    for head in range(self.num_heads):
      # this head's slice of the features: (batch, seq_len, dk)
      part = sequences[..., head * self.dk:(head + 1) * self.dk]
      q = self.query_projection[head](part)
      k = self.key_projection[head](part)
      v = self.value_projection[head](part)
      scores = q @ k.transpose(-2, -1) / (self.dk ** 0.5)  # (batch, seq_len, seq_len)
      head_outputs.append(self.softmax(scores) @ v)
    return self.output_projection(torch.cat(head_outputs, dim=-1))


## Feed Forward Neural Network


In [7]:
class Feed_Forward_NN(nn.Module):
  """Position-wise two-layer MLP: Linear(d_model -> hidden_dim), GELU, Linear(hidden_dim -> d_model).

  Args:
    d_model: input and output feature size.
    hidden_dim: width of the intermediate layer.
  """

  def __init__(self, d_model, hidden_dim):
    super().__init__()
    # creation order is kept: it fixes the parameter-initialisation order under a seed
    self.linear_layer1 = nn.Linear(d_model, hidden_dim)
    self.linear_layer2 = nn.Linear(hidden_dim, d_model)
    self.gelu = nn.GELU()

  def forward(self, x):
    """Map (..., d_model) to (..., d_model)."""
    hidden = self.gelu(self.linear_layer1(x))
    return self.linear_layer2(hidden)



# Encoder Layer

Dropout:
***To boost the performance on the smaller datasets, we optimize three basic regularization parameters – weight decay, dropout, and label smoothing.***
(quoted from [the ViT paper](https://arxiv.org/pdf/2010.11929.pdf))

In [8]:
class EncoderLayer(nn.Module):
  """One pre-norm Transformer encoder block.

      x = x + Dropout(SelfAttention(LayerNorm(x)))
      x = x + Dropout(FeedForward(LayerNorm(x)))

  Args:
    d_model: token feature size.
    num_heads: number of attention heads.
    dropout: dropout probability applied to both sub-layer outputs. The default
      0.5 matches the previous bare `nn.Dropout()`; tune it down (ViT setups
      commonly use much smaller values) for better results.
    mlp_dim: hidden width of the feed-forward sub-layer; defaults to `d_model`
      (the previous hard-coded width).
  """

  def __init__(self, d_model, num_heads, dropout=0.5, mlp_dim=None):
    super().__init__()
    self.self_attention = MultiHead_Attention(d_model, num_heads)
    self.feed_forward = Feed_Forward_NN(d_model, d_model if mlp_dim is None else mlp_dim)
    # LayerNorms applied before each sub-layer (pre-norm)
    self.norm_layer1 = nn.LayerNorm(d_model)
    self.norm_layer2 = nn.LayerNorm(d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    """Apply the block to tokens `x` of shape (batch, seq_len, d_model); returns the same shape."""
    # self-attention sub-layer: normalisation, attention, dropout, residual connection
    x = x + self.dropout(self.self_attention(self.norm_layer1(x)))
    # feed-forward sub-layer: normalisation, MLP, dropout, residual connection
    x = x + self.dropout(self.feed_forward(self.norm_layer2(x)))
    return x



# Transformer Encoder

In [9]:
class TransformerEncoder(nn.Module):
  """Stack of `nb_block` EncoderLayer blocks applied one after another.

  Args:
    nb_block: number of encoder blocks.
    d_model: token feature size (default 256).
    num_heads: attention heads per block (default 2).
  """

  def __init__(self, nb_block, d_model=256, num_heads=2):
    super().__init__()
    self.nb_block = nb_block
    blocks = [EncoderLayer(d_model, num_heads) for _ in range(nb_block)]
    self.encoder_blocks = nn.ModuleList(blocks)

  def forward(self, tokens):
    """Feed `tokens` through every block in order and return the final output."""
    for idx in range(len(self.encoder_blocks)):
      tokens = self.encoder_blocks[idx](tokens)
    return tokens


## Classifier (MLP)

In [10]:
class Classifier(nn.Module):
  """Linear classification head mapping a `d_model` feature vector to `out_dim` class scores.

  By default it returns raw logits, which is what `nn.CrossEntropyLoss` expects
  (it applies log-softmax internally). Applying softmax before that loss, as
  before, squashes the gradients and slows training.

  Args:
    d_model: input feature size.
    out_dim: number of classes.
    apply_softmax: if True, return softmax probabilities instead of logits.
  """

  def __init__(self, d_model, out_dim, apply_softmax=False):
    super().__init__()
    self.d_model = d_model
    self.apply_softmax = apply_softmax
    self.linear = nn.Linear(d_model, out_dim)
    self.softmax = nn.Softmax(dim=-1)

  def forward(self, x):
    logits = self.linear(x)
    # argmax is the same for logits and probabilities, so predictions are unaffected
    return self.softmax(logits) if self.apply_softmax else logits

# ViT Transformer

In [11]:
# vit block
class Vit_Block(nn.Module):
  """Vision Transformer: token embedding -> Transformer encoder -> classification head on the class token.

  Args:
    nb_block: number of encoder blocks.
    nb_heads: attention heads per block; must divide the model width.
    image_size: side length of the square input images.
    n_channels: number of image channels.
    patch_size: side length of each square patch.
    nb_classes: number of output classes.
    d_model: token width. Defaults to 3 * (image_size // patch_size) ** 2, the
      value previously hard-coded (192 for 32x32 images with 4x4 patches).
  """

  def __init__(self, nb_block, nb_heads, image_size, n_channels, patch_size, nb_classes, d_model=None):
    super().__init__()
    if d_model is None:
      d_model = 3 * (image_size // patch_size) ** 2
    self.d_model = d_model
    self.nb_block = nb_block
    self.nb_heads = nb_heads
    self.transformer = TransformerEncoder(nb_block, self.d_model, nb_heads)
    self.embedding_tokens = Embedding_tokens(image_size, n_channels, patch_size, self.d_model)
    self.mlp = Classifier(self.d_model, nb_classes)

  def forward(self, x):
    """Return the classification head's output for a batch of images `x`."""
    tokens = self.embedding_tokens(x)
    encoded = self.transformer(tokens)
    # the classification token sits at sequence position 0
    return self.mlp(encoded[:, 0])


In [12]:
class VitTransformer(pl.LightningModule):
  """LightningModule wrapping `Vit_Block` for CIFAR-10 classification.

  The training, validation and test steps compute `nn.CrossEntropyLoss` on the
  output of `Vit_Block`. The test hooks also accumulate accuracy over the whole
  test set and print the real and predicted classes of the first test batch.
  """
  # CIFAR-10 class names, indexed by label id
  CLS = ('plane', 'car', 'bird', 'cat',
         'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

  def __init__(self, nb_block, nb_heads, image_size, n_channels, patch_size, nb_classes):
    super().__init__()
    self.nb_block = nb_block
    self.nb_heads = nb_heads
    self.vit_block = Vit_Block(nb_block, nb_heads, image_size, n_channels, patch_size, nb_classes)
    self.loss = nn.CrossEntropyLoss()

  def compute_loss(self, y, y_hat):
    """Cross-entropy loss. `y` is the model output (batch, nb_classes) and `y_hat` the integer targets.

    The argument names are kept for backward compatibility; callers pass them positionally.
    """
    return self.loss(y, y_hat)

  def forward(self, x):
    return self.vit_block(x)

  def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.vit_block.parameters(), lr=1e-4)
    return optimizer

  def training_step(self, batch, batch_idx):
    x, y = batch
    out_vit = self.vit_block(x)
    model_loss = self.compute_loss(out_vit, y)
    self.log("training_loss", model_loss, prog_bar=True, logger=True)
    return model_loss

  def validation_step(self, batch, batch_idx):
    """Log the validation loss and this batch's accuracy."""
    x, y = batch
    out_vit = self.vit_block(x)
    model_loss = self.compute_loss(out_vit, y)
    correct = (torch.argmax(out_vit, dim=1) == y).float().mean().item()
    self.log_dict({"validation_loss": model_loss, "validation_accuracy": correct}, prog_bar=True, logger=True)
    return model_loss

  def on_test_start(self):
    # running totals for the whole test set, plus the first batch for display
    self.correct = 0
    self.total = 0
    self.predicted = None
    self.real = None

  def test_step(self, batch, batch_idx):
    x, y = batch
    out_vit = self.vit_block(x)
    model_loss = self.compute_loss(out_vit, y)

    predictions = torch.argmax(out_vit, dim=1)
    self.correct += (predictions == y).sum().item()
    self.total += x.size(0)

    if batch_idx == 0:
      self.predicted = predictions
      self.real = y

    self.log("test_loss", model_loss, prog_bar=True, logger=True)
    return model_loss

  def on_test_end(self):
    """Print real vs. predicted classes of the first test batch and the overall test accuracy."""
    if self.real is None or self.total == 0:
      print("No test batches were evaluated.")
      return
    real_classes = ''.join(' ' + self.CLS[label] for label in self.real.tolist())
    predicted_classes = ''.join(' ' + self.CLS[label] for label in self.predicted.tolist())
    print("real classes : ", real_classes)
    # previously printed real_classes here by mistake
    print("predicted classes : ", predicted_classes)
    accuracy = (self.correct / self.total) * 100
    print("Accuracy ", accuracy, "%")


# CIFAR-10 DataModule

In [13]:
# Create the datamodule for CIFAR-10 dataset
class CIFAR10DataModule(pl.LightningDataModule):
    """LightningDataModule serving CIFAR-10 train, validation and test loaders.

    The official training images are split, with a fixed seed, into a training
    subset and a held-out validation subset of `val_size` images. The official
    test images are used only for testing. With `val_size=0` the whole training
    set is used for training and validation falls back to the test split (the
    previous behaviour).

    Args:
        data_dir: directory where CIFAR-10 is downloaded to and read from.
        batch_size: batch size of every loader.
        image_size: images are resized to (image_size, image_size).
        val_size: number of training images held out for validation.
        seed: seed of the generator used for the train/validation split.
    """

    # Per-channel mean / std of the CIFAR-10 training images in [0, 1] (after ToTensor).
    MEAN = (0.4914, 0.4822, 0.4465)
    STD = (0.2470, 0.2435, 0.2616)

    def __init__(self, data_dir: str = "./", batch_size=32, image_size=32, val_size: int = 5000, seed: int = 42):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.val_size = val_size
        self.seed = seed
        # Resize the PIL image first, then convert and normalise with CIFAR-10 statistics
        # (the previous (0.1307,), (0.3081,) values are MNIST statistics).
        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize((image_size, image_size)),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(self.MEAN, self.STD),
        ])
        self.cifar10_train = None
        self.cifar10_val = None
        self.cifar10_test = None

    def prepare_data(self):
        # download only; datasets with transforms are built in setup()
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage: str):
        """Build the datasets needed for `stage` ("fit"/"validate" or "test") from `data_dir`."""
        if stage in ("fit", "validate") and self.cifar10_train is None:
            full_train = CIFAR10(self.data_dir, train=True, transform=self.transform)
            if self.val_size > 0:
                # fixed generator so the split is identical across runs
                generator = torch.Generator().manual_seed(self.seed)
                lengths = [len(full_train) - self.val_size, self.val_size]
                self.cifar10_train, self.cifar10_val = random_split(full_train, lengths, generator=generator)
            else:
                self.cifar10_train = full_train
                self.cifar10_val = CIFAR10(self.data_dir, train=False, transform=self.transform)

        if stage == "test" and self.cifar10_test is None:
            self.cifar10_test = CIFAR10(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        if self.cifar10_train is None:
            self.setup("fit")
        return DataLoader(self.cifar10_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        if self.cifar10_val is None:
            self.setup("fit")
        return DataLoader(self.cifar10_val, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        if self.cifar10_test is None:
            self.setup("test")
        return DataLoader(self.cifar10_test, batch_size=self.batch_size, shuffle=False)


# Test training

In [None]:
# Create the model: 3 encoder blocks, 6 heads, 4x4 patches on 32x32 RGB images, 10 classes.
pl.seed_everything(42)  # reproducible weight init, shuffling and dropout
model = VitTransformer(nb_block=3, nb_heads=6, image_size=32, n_channels=3, patch_size=4, nb_classes=10)
dataset = CIFAR10DataModule(data_dir=os.getcwd(), batch_size=32, image_size=32)
# accelerator defaults to "auto", which uses a GPU when one is available
trainer = pl.Trainer(max_epochs=30)
trainer.fit(model, dataset)


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /content/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 59664360.88it/s]


Extracting /content/cifar-10-python.tar.gz to /content
Files already downloaded and verified


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name      | Type             | Params
-----------------------------------------------
0 | vit_block | Vit_Block        | 404 K 
1 | loss      | CrossEntropyLoss | 0     
-----------------------------------------------
404 K     Trainable params
0         Non-trainable params
404 K     Total params
1.617     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [None]:
# Evaluate on the CIFAR-10 test split. The Trainer calls `dataset.setup(stage="test")`
# itself; ckpt_path="best" makes explicit that the best checkpoint from the `fit`
# above is loaded (the default when no model is passed).
trainer.test(datamodule=dataset, ckpt_path="best")
