##### Implementing a Vision Transformer (ViT) from Scratch with PyTorch

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


# Reproducibility: fix the RNG so the random weight init and random inputs below
# are the same on every Restart & Run All. torch.manual_seed seeds all devices.
SEED = 42
torch.manual_seed(SEED)

# Pick the fastest available backend: CUDA GPU, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

device

'cuda'

In [4]:
# Patch embedding: a strided convolution that turns an image into a sequence of patch tokens.
class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping square patches and linearly embed each one.

    NOTE(review): the channel arguments are used in swapped order -- the conv reads
    ``embed_dim`` input channels and produces ``in_channels`` output channels. The
    ``ViT`` call site compensates by passing ``(embed_dim, in_channels, patch_size)``;
    if this is ever un-swapped, that call must be changed at the same time.
    """

    def __init__(self, in_channels, embed_dim, patch_size):
        super().__init__()
        # kernel_size == stride: the conv windows tile the image without overlap,
        # so each output position corresponds to exactly one patch.
        self.conv2d = nn.Conv2d(
            in_channels=embed_dim,
            out_channels=in_channels,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Map a batched image tensor to a (batch, num_patches, channels) token sequence.

        Assumes ``x`` is 4-D, (B, C_in, H, W) -- TODO confirm no unbatched callers.
        """
        feature_map = self.conv2d(x)               # (B, C_out, H', W')
        tokens = feature_map.flatten(start_dim=2)  # (B, C_out, H' * W')
        return tokens.transpose(1, 2)              # (B, H' * W', C_out)

In [5]:
class ViT(nn.Module):
    """Vision Transformer (ViT) image classifier built from PyTorch primitives.

    Images are cut into non-overlapping ``patch_size`` x ``patch_size`` patches that
    are projected to ``embed_dim``-wide tokens. A learnable [CLS] token is prepended,
    learnable position embeddings are added, and the sequence passes through a stack
    of Transformer encoder layers. The final [CLS] representation is layer-normalised
    and mapped to ``num_classes`` logits.

    Args:
        img_size: side length of the square input image. It fixes the number of
            position embeddings, so inputs must yield ``(img_size // patch_size) ** 2``
            patches.
        patch_size: side length of each square patch (conv kernel size and stride).
        in_channels: number of channels in the input image.
        num_classes: number of output logits.
        embed_dim: token width (``d_model`` of the encoder).
        depth: number of encoder layers.
        n_head: number of attention heads (PyTorch requires it to divide ``embed_dim``).
        ff_dim: hidden width of each encoder layer's feed-forward block.
        dropout: dropout probability after the embeddings and inside the encoder.
        norm_first: if True (default), use pre-LayerNorm encoder blocks as in the
            original ViT. Pass False for PyTorch's post-norm default.
    """

    def __init__(self, img_size, patch_size=16, in_channels=3, num_classes=1000,
                 embed_dim=768, depth=12, n_head=12, ff_dim=3072, dropout=0.1,
                 norm_first=True):
        super().__init__()
        # PatchEmbedding's signature is (in_channels, embed_dim, patch_size), but it
        # builds nn.Conv2d(embed_dim, in_channels, ...), i.e. it uses its channel
        # arguments in swapped order. Passing (embed_dim, in_channels) here therefore
        # yields a conv mapping in_channels -> embed_dim.
        # NOTE(review): un-swap PatchEmbedding and this call together.
        self.patch_embed = PatchEmbedding(embed_dim, in_channels, patch_size)
        # Small-std random init for the learnable [CLS] token and position embeddings.
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim) * 0.02)
        num_patches = (img_size // patch_size) ** 2
        # One position embedding per patch plus one for the [CLS] token.
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim) * 0.02)
        self.dropout = nn.Dropout(dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=n_head,
            dim_feedforward=ff_dim,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=norm_first,
        )
        # The nested-tensor fast path is unavailable for pre-norm layers. Disabling it
        # up front avoids PyTorch's construction-time warning; post-norm keeps the default.
        self.encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=depth, enable_nested_tensor=not norm_first
        )
        self.layernorm = nn.LayerNorm(embed_dim)
        self.output = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: images, presumably (B, in_channels, img_size, img_size). The resulting
                patch count must match the number of position embeddings.

        Returns:
            Logits of shape (B, num_classes).
        """
        z = self.patch_embed(x)                               # (B, num_patches, embed_dim)
        cls_expd = self.cls_token.expand(z.shape[0], -1, -1)  # one [CLS] per sample
        z = torch.cat((cls_expd, z), dim=1)                   # prepend [CLS]
        z = z + self.pos_embed
        z = self.dropout(z)
        z = self.encoder(z)
        # Classify from the [CLS] token only.
        z = self.layernorm(z[:, 0])
        logits = self.output(z)
        return logits


In [6]:
# Smoke-test the from-scratch model with ViT-Base/16 hyperparameters on a random
# batch. The model and the batch are both placed on the `device` chosen in the
# setup cell, so the forward pass runs on the accelerator when one is present.
vit_model = ViT(
    img_size=224,
    patch_size=16,
    in_channels=3,
    num_classes=1000,
    embed_dim=768,
    depth=12,
    n_head=12,
    ff_dim=3072,
    dropout=0.1,
).to(device)
batch = torch.randn(4, 3, 224, 224, device=device)
# Only the output shape is inspected, so skip recording the autograd graph.
with torch.no_grad():
    logits = vit_model(batch)

In [8]:
# Sanity check: the model should emit one row of `num_classes` logits per image.
assert logits.shape == (batch.shape[0], vit_model.output.out_features)
logits.shape

torch.Size([4, 1000])

##### Fine-Tuning a Pretrained ViT

In [None]:
# Load the Oxford-IIIT Pet dataset from the Hugging Face Hub.
from datasets import load_dataset

pets = load_dataset(path='timm/oxford-iiit-pet')

README.md: 0.00B [00:00, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


data/train-00000-of-00001.parquet:   0%|          | 0.00/378M [00:00<?, ?B/s]

In [None]:
####