<a href="https://colab.research.google.com/github/achilela/AreaB/blob/main/VIT_Details.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install einops==0.7.0
!pip install torch torchvision

Collecting einops==0.7.0
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/44.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.7.0
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m60.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m823.6/823.6 kB[0m [31m68.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Dow

In [None]:
# The code below is taken from https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py
# and has only been extended with comments that reference the steps of this blog post:
# https://blog.mdturp.ch/posts/2024-04-05-visual_guide_to_vision_transformer.html


import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    """Return ``t`` itself if it is already a tuple, otherwise ``(t, t)``.

    Lets sizes be passed either as a single value (square) or as an explicit
    ``(height, width)`` tuple.
    """
    if isinstance(t, tuple):
        return t
    return (t, t)

# classes

class FeedForward(nn.Module):
    """Pre-norm position-wise MLP applied inside every transformer layer.

    Pipeline: LayerNorm -> Linear(dim -> hidden_dim) -> GELU -> Dropout
    -> Linear(hidden_dim -> dim) -> Dropout, so the last axis keeps size ``dim``.

    Args:
        dim: size of the last axis of the input and output.
        hidden_dim: width of the hidden layer.
        dropout: dropout probability used after the activation and after the
            output projection.
    """

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        # The position of each layer in this Sequential becomes its state_dict
        # key (net.0, net.1, ...), so the order must not change.
        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        # Blogpost step 10.11
        out = self.net(x)
        return out

class Attention(nn.Module):
    """Pre-norm multi-head self-attention.

    The input is layer-normalised and projected by a single bias-free linear
    layer into queries, keys and values of total width ``heads * dim_head``.
    Each head computes softmax(Q K^T * dim_head**-0.5) V, and the heads are
    concatenated and projected back to ``dim``. The output projection is
    replaced by an Identity only when one head already has width ``dim``.

    Args:
        dim: size of the last axis of the input and output.
        heads: number of attention heads.
        dim_head: per-head width of the query/key/value projections.
        dropout: dropout probability applied to the attention weights and
            after the output projection.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head
        # Equivalent to: not (heads == 1 and dim_head == dim)
        needs_projection = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head ** -0.5  # 1 / sqrt(dim_head) dot-product scaling

        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        # Q, K and V are produced side by side on the last axis by one matmul.
        self.to_qkv = nn.Linear(dim, 3 * inner_dim, bias = False)

        if needs_projection:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout),
            )
        else:
            self.to_out = nn.Identity()

    # Blogpost step 10.6
    def forward(self, x):
        # The einops patterns below fix x as rank 3: (batch, tokens, dim).
        normed = self.norm(x)

        # Blogpost step 10.1: split the fused projection into Q, K, V and move
        # the head axis in front of the token axis.
        fused = self.to_qkv(normed)
        q, k, v = (
            rearrange(part, 'b n (h d) -> b h n d', h = self.heads)
            for part in fused.chunk(3, dim = -1)
        )

        # Blogpost step 10.2-10.3: scaled scores, softmax over the key axis.
        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.dropout(self.attend(scores))

        # Blogpost step 10.4-10.5: weighted sum of the values.
        mixed = weights @ v

        # Blogpost step 10.7-10.8: concatenate the heads, then project.
        merged = rearrange(mixed, 'b h n d -> b n (h d)')
        return self.to_out(merged)

class Transformer(nn.Module):
    """Stack of ``depth`` residual (Attention, FeedForward) layers plus a final LayerNorm.

    Both sub-blocks normalise their own input (pre-norm), and each layer adds
    its sub-block outputs back onto the running token representation.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        # One [Attention, FeedForward] pair per layer; construction order is
        # attention then feed-forward, layer by layer.
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout),
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        # Blogpost step 11: run every layer in sequence.
        for attention_block, feed_forward_block in self.layers:
            # Blogpost steps 10.1-10.10 (residual add = step 10.9)
            x = x + attention_block(x)
            # Blogpost steps 10.11-10.12
            x = x + feed_forward_block(x)

        return self.norm(x)

class ViT(nn.Module):
    """Vision Transformer classifier.

    Splits an image of shape (batch, channels, height, width) into
    non-overlapping ``patch_size`` patches and embeds each patch to ``dim``. It
    prepends a learnable class token, adds learnable positional embeddings, and
    runs a ``Transformer``. It then pools either the class-token output
    (``pool='cls'``) or the mean over all tokens (``pool='mean'``, which includes
    the class token), and maps the result to ``num_classes`` logits.

    Raises:
        AssertionError: if the image size is not divisible by the patch size,
            or if ``pool`` is not 'cls' or 'mean'.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        patches_per_column = image_height // patch_height
        patches_per_row = image_width // patch_width
        num_patches = patches_per_column * patches_per_row
        # Length of one flattened patch vector (all channels of a p1 x p2 tile).
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # Blogpost steps 3-4 (Rearrange) and step 5 (norm -> linear -> norm).
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # One positional vector per patch plus one for the class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        # Blogpost step 13
        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        # Blogpost steps 3-6: image -> sequence of patch embeddings.
        tokens = self.to_patch_embedding(img)
        batch, num_tokens, _ = tokens.shape

        # Blogpost step 7: one copy of the class token per sample, placed first.
        cls = repeat(self.cls_token, '1 1 d -> b 1 d', b = batch)
        tokens = torch.cat((cls, tokens), dim = 1)

        # Blogpost step 8-9
        tokens += self.pos_embedding[:, :num_tokens + 1]
        tokens = self.dropout(tokens)

        # Blogpost step 10-11
        tokens = self.transformer(tokens)

        # Blogpost step 12
        if self.pool == 'mean':
            pooled = tokens.mean(dim = 1)
        else:
            pooled = tokens[:, 0]

        # Blogpost step 13
        return self.mlp_head(self.to_latent(pooled))



# Training the ViT Model on CIFAR-10

The code below shows how to train the ViT model on the CIFAR-10 dataset.



In [None]:
import torch.optim as optim
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR
import tqdm

import torchvision.transforms as transforms
import torchvision


# Convert images to tensors, then normalise each RGB channel as (x - 0.5) / 0.5,
# which maps ToTensor's [0, 1] range to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 64

# download=True fetches CIFAR-10 into ./data on the first run and reuses it later.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=0)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=0)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


# Training hyperparameters.
epochs = 20
lr = 3e-5
gamma = 0.7  # multiplicative learning-rate decay applied per StepLR step
seed = 42
# Use the GPU when one is present instead of failing on CPU-only runtimes.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed torch's global RNG before the model is built so that weight
# initialisation is reproducible. This presumably also fixes the DataLoader
# shuffle order, because no explicit generator is passed to the loaders.
torch.manual_seed(seed)

# A 32x32 image with 8x8 patches gives 4 * 4 = 16 patch tokens plus one class
# token. With the default dim_head=64, attention runs at width 12 * 64 = 768
# and is projected back to dim=1024.
model = ViT(
    image_size = 32,
    patch_size = 8,
    num_classes = 10,
    dim = 1024,
    depth = 6,
    heads = 12,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler: multiply the learning rate by `gamma` on every scheduler.step()
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:12<00:00, 13136723.46it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [None]:
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Returns (mean loss per sample, accuracy) as Python floats. Assumes
    ``criterion`` uses mean reduction (the nn.CrossEntropyLoss default), so
    ``loss * batch`` recovers the per-batch sum.
    """
    model.train()  # enable dropout; _evaluate switches the model to eval mode
    total_loss, total_correct, total_seen = 0.0, 0, 0

    for data, label in tqdm.tqdm(loader):
        data = data.to(device)
        label = label.to(device)

        output = model(data)

        # Blogpost step 14
        loss = criterion(output, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch = label.size(0)
        # .item() yields plain floats/ints, so no autograd graph is kept alive
        # across the epoch. Weighting by batch size makes the smaller final
        # batch count correctly.
        total_loss += loss.item() * batch
        total_correct += (output.argmax(dim=1) == label).sum().item()
        total_seen += batch

    seen = max(total_seen, 1)  # guard against an empty loader
    return total_loss / seen, total_correct / seen


@torch.no_grad()
def _evaluate(model, loader, criterion, device):
    """Return (mean loss per sample, accuracy) of ``model`` on ``loader`` without updating weights."""
    model.eval()  # disable dropout so evaluation uses the full network
    total_loss, total_correct, total_seen = 0.0, 0, 0

    for data, label in loader:
        data = data.to(device)
        label = label.to(device)

        test_output = model(data)
        test_loss = criterion(test_output, label)

        batch = label.size(0)
        total_loss += test_loss.item() * batch
        total_correct += (test_output.argmax(dim=1) == label).sum().item()
        total_seen += batch

    seen = max(total_seen, 1)  # guard against an empty loader
    return total_loss / seen, total_correct / seen


print("Start training")

model.to(device)

for epoch in range(epochs):
    epoch_loss, epoch_accuracy = _train_one_epoch(model, train_loader, criterion, optimizer, device)
    epoch_test_loss, epoch_test_accuracy = _evaluate(model, test_loader, criterion, device)

    # Apply the configured StepLR decay once per epoch (step_size=1).
    # NOTE: with gamma=0.7 the learning rate shrinks by 0.7**19 (about 900x)
    # over 20 epochs; retune gamma if that decay is too aggressive.
    scheduler.step()

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - test_loss : {epoch_test_loss:.4f} - test_acc: {epoch_test_accuracy:.4f}\n"
    )

Start training


100%|██████████| 782/782 [01:18<00:00,  9.99it/s]


Epoch : 1 - loss : 1.6214 - acc: 0.4142 - test_loss : 1.4091 - test_acc: 0.4976



100%|██████████| 782/782 [01:19<00:00,  9.89it/s]


Epoch : 2 - loss : 1.3250 - acc: 0.5254 - test_loss : 1.3284 - test_acc: 0.5237



100%|██████████| 782/782 [01:18<00:00,  9.90it/s]


Epoch : 3 - loss : 1.1981 - acc: 0.5715 - test_loss : 1.2890 - test_acc: 0.5364



100%|██████████| 782/782 [01:19<00:00,  9.90it/s]


Epoch : 4 - loss : 1.0985 - acc: 0.6082 - test_loss : 1.2491 - test_acc: 0.5607



100%|██████████| 782/782 [01:19<00:00,  9.85it/s]


Epoch : 5 - loss : 1.0038 - acc: 0.6435 - test_loss : 1.2418 - test_acc: 0.5640



100%|██████████| 782/782 [01:19<00:00,  9.90it/s]


Epoch : 6 - loss : 0.9156 - acc: 0.6722 - test_loss : 1.2395 - test_acc: 0.5727



100%|██████████| 782/782 [01:18<00:00,  9.90it/s]


Epoch : 7 - loss : 0.8137 - acc: 0.7111 - test_loss : 1.2375 - test_acc: 0.5813



100%|██████████| 782/782 [01:19<00:00,  9.90it/s]


Epoch : 8 - loss : 0.7163 - acc: 0.7446 - test_loss : 1.2614 - test_acc: 0.5798



100%|██████████| 782/782 [01:18<00:00,  9.90it/s]


Epoch : 9 - loss : 0.6204 - acc: 0.7795 - test_loss : 1.2963 - test_acc: 0.5782



100%|██████████| 782/782 [01:18<00:00,  9.90it/s]


Epoch : 10 - loss : 0.5253 - acc: 0.8134 - test_loss : 1.3674 - test_acc: 0.5874



100%|██████████| 782/782 [01:19<00:00,  9.90it/s]


Epoch : 11 - loss : 0.4400 - acc: 0.8435 - test_loss : 1.4403 - test_acc: 0.5830



100%|██████████| 782/782 [01:19<00:00,  9.90it/s]


Epoch : 12 - loss : 0.3686 - acc: 0.8696 - test_loss : 1.4795 - test_acc: 0.5929



100%|██████████| 782/782 [01:19<00:00,  9.90it/s]


Epoch : 13 - loss : 0.3011 - acc: 0.8940 - test_loss : 1.6237 - test_acc: 0.5827



100%|██████████| 782/782 [01:19<00:00,  9.90it/s]


Epoch : 14 - loss : 0.2511 - acc: 0.9121 - test_loss : 1.6776 - test_acc: 0.5857



100%|██████████| 782/782 [01:19<00:00,  9.90it/s]


Epoch : 15 - loss : 0.2147 - acc: 0.9245 - test_loss : 1.7540 - test_acc: 0.5824



100%|██████████| 782/782 [01:19<00:00,  9.90it/s]


Epoch : 16 - loss : 0.1790 - acc: 0.9358 - test_loss : 1.8704 - test_acc: 0.5832



100%|██████████| 782/782 [01:19<00:00,  9.90it/s]


Epoch : 17 - loss : 0.1587 - acc: 0.9452 - test_loss : 1.8867 - test_acc: 0.5938



100%|██████████| 782/782 [01:19<00:00,  9.90it/s]


Epoch : 18 - loss : 0.1407 - acc: 0.9512 - test_loss : 1.9751 - test_acc: 0.5815



100%|██████████| 782/782 [01:19<00:00,  9.90it/s]


Epoch : 19 - loss : 0.1283 - acc: 0.9551 - test_loss : 1.9614 - test_acc: 0.5878



100%|██████████| 782/782 [01:19<00:00,  9.90it/s]


Epoch : 20 - loss : 0.1155 - acc: 0.9600 - test_loss : 2.0299 - test_acc: 0.5899

