
This notebook was made by **Eng. Mohammed Alshabrawi**

**An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale**

This is the second notebook on Transformers. We will implement the Vision Transformer **(ViT)** from scratch!

It is pretty simple if you understand the Transformer architecture ([Transformer from scratch](https://github.com/Mohamad-Atif1/paper2code/blob/main/Transformers/Transformer.ipynb)).

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

Images are 3D data (height, width, and color channels), while Transformers are designed for 1D sequential data (sequences of words or tokens). So the first step is to convert each image into a **sequence of 1D embeddings**.

This is achieved by:

- Splitting the image into fixed-size, non-overlapping patches. For example, a 224x224 image might be divided into 196 patches, where each patch is 16x16 pixels.

- Applying a linear projection to each of these 2D patches. This learnable transformation maps the flattened pixel values of each patch to a vector whose size is the embedding size, known as a patch embedding. It is often implemented with a Conv2d layer whose kernel size and stride both equal the patch size and whose number of output channels (kernels) equals the embedding size.

- Finally, we obtain a sequence of 1D embeddings, one per patch (196 in the example above).
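To see why a strided convolution is the same thing as "split into patches, flatten, apply one shared linear layer", here is a minimal sketch with toy sizes (a 32x32 RGB image, 16x16 patches, embedding size 8, all chosen only for illustration). It builds the patches explicitly with `Tensor.unfold`, copies the Conv2d kernels into an `nn.Linear`, and checks that both paths produce the same tokens:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W, P, D = 2, 3, 32, 32, 16, 8   # toy sizes chosen only for illustration
img = torch.randn(B, C, H, W)

# 1) Split into non-overlapping P x P patches and flatten each one
patches = img.unfold(2, P, P).unfold(3, P, P)   # B, C, H/P, W/P, P, P
patches = patches.permute(0, 2, 3, 1, 4, 5)     # B, H/P, W/P, C, P, P
patches = patches.reshape(B, -1, C * P * P)     # B, N, C*P*P  (N = H*W / P^2 = 4)

# 2) One shared linear projection applied to every flattened patch
conv = nn.Conv2d(C, D, kernel_size=P, stride=P)
linear = nn.Linear(C * P * P, D)
with torch.no_grad():
    # Give the Linear layer exactly the same weights as the Conv2d kernels
    linear.weight.copy_(conv.weight.reshape(D, -1))
    linear.bias.copy_(conv.bias)

tokens_linear = linear(patches)                      # B, N, D
tokens_conv = conv(img).flatten(2).transpose(1, 2)   # B, N, D
print(tokens_conv.shape)                              # torch.Size([2, 4, 8])
print(torch.allclose(tokens_linear, tokens_conv, atol=1e-5))  # True
```

The convolution form is mostly a convenience: it performs the patch extraction and the projection in a single call.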


We also add positional embeddings. Self-attention looks at all parts of a sequence at once and has no built-in notion of order, so without them the model would not know where each patch came from. In the original paper, the positional embeddings are learnable vectors.
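A quick way to see why positional embeddings are needed: without them, self-attention is permutation-equivariant, so shuffling the patches only shuffles the outputs. The sketch below uses PyTorch's `nn.MultiheadAttention` and toy sizes (5 tokens of size 16) purely for illustration; the positional embedding here is just a random tensor standing in for a learned one:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
x = torch.randn(1, 5, 16)              # 1 image, 5 patch embeddings
perm = torch.tensor([3, 0, 4, 1, 2])   # a shuffled patch order

with torch.no_grad():
    # No positional information: shuffling the input only shuffles the output
    out, _ = attn(x, x, x)
    xs = x[:, perm]
    out_shuffled, _ = attn(xs, xs, xs)
    same_without_pos = torch.allclose(out[:, perm], out_shuffled, atol=1e-5)

    # With a per-slot positional embedding added, the same shuffle changes the result
    pos_embed = torch.randn(1, 5, 16)
    xp = x + pos_embed
    out_pos, _ = attn(xp, xp, xp)
    xsp = xs + pos_embed
    out_shuffled_pos, _ = attn(xsp, xsp, xsp)
    same_with_pos = torch.allclose(out_pos[:, perm], out_shuffled_pos, atol=1e-5)

print(same_without_pos)  # True
print(same_with_pos)
```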

<img src="https://miro.medium.com/v2/resize:fit:1400/1*tjEPjCT4Os-mRI3cTn0jVg.png" height=400>

In [None]:
class PatchEmbedding(nn.Module):
    def __init__(self, img_size, patch_size, embed_size, in_channels):
        super(PatchEmbedding, self).__init__()
        assert img_size % patch_size == 0, f"Image size {img_size} must be divisible by patch size {patch_size}"
        self.embed_size = embed_size
        # number of patches = (Height * Width) / (patch_height * patch_width)
        # In the original paper, images and patches are square (Height = Width), so
        # number of patches = Height^2 / patch_height^2 = (Height / patch_height)^2
        self.n_patches = (img_size // patch_size) ** 2

        # Extract "embed_size" features from each patch separately using Conv2d.
        # kernel_size = stride = patch_size, so the patches do not overlap.
        # In the paper this is called "Linear Projection of Flattened Patches".
        self.projection = nn.Conv2d(in_channels, embed_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, img):
        img = self.projection(img)  # bs, embed_size, H/P, W/P
        # Flatten the patch grid into a sequence and move the embedding axis last.
        # A plain reshape to (bs, n_patches, embed_size) would mix values from different patches.
        img = img.flatten(2).transpose(1, 2)  # bs, seq_len (n_patches), embed_size
        return img
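The convolution returns a tensor of shape (bs, embed_size, H/P, W/P), with the embedding on axis 1 and the patch grid on the last two axes. Turning it into a (bs, n_patches, embed_size) sequence therefore needs the embedding axis moved last; a direct `reshape` keeps the memory order and puts features of different patches into one token. A small check with toy sizes (a 2x2 grid of patches, embedding size 4), chosen only for illustration:

```python
import torch

torch.manual_seed(0)
bs, embed_size, grid = 1, 4, 2                    # toy sizes: a 2x2 grid of patches
feat = torch.randn(bs, embed_size, grid, grid)   # shape of the Conv2d output: bs, E, H/P, W/P

good = feat.flatten(2).transpose(1, 2)            # bs, n_patches, E
bad = feat.reshape(bs, grid * grid, embed_size)   # same shape, but wrong content

top_left = feat[0, :, 0, 0]                       # the E features of the top-left patch
print(torch.equal(good[0, 0], top_left))          # True
print(torch.equal(bad[0, 0], top_left))           # False: holds channel 0 of all four patches
```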

Next, we pass this sequence of embeddings to the encoder. Before doing that, we prepend an **extra learnable embedding**, known as the class token \<CLS\>, to the patch embeddings. Its purpose is to gather global information from the entire image as it passes through the Transformer Encoder layers.

We classify the image by passing the final \<CLS\> token representation to a classifier (the MLP head).

For the Transformer Encoder layers, we use PyTorch's **nn.TransformerEncoderLayer**.
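To illustrate how the \<CLS\> token can gather information from the whole image, the sketch below prepends a learnable CLS vector to random patch embeddings, runs a single `nn.TransformerEncoderLayer`, and checks that the CLS output depends on every patch (its gradient with respect to each patch embedding is non-zero). All sizes are toy values chosen for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bs, n_patches, d = 1, 16, 32                      # toy sizes
patches = torch.randn(bs, n_patches, d, requires_grad=True)
cls = nn.Parameter(torch.randn(1, 1, d))          # one learnable vector shared by all images

layer = nn.TransformerEncoderLayer(d_model=d, nhead=4, dim_feedforward=64, dropout=0.0,
                                   batch_first=True, norm_first=True)
tokens = torch.cat([cls.expand(bs, -1, -1), patches], dim=1)  # bs, n_patches + 1, d
out = layer(tokens)

cls_out = out[:, 0]            # the representation the MLP head would see
cls_out.sum().backward()

# Gradient magnitude of the CLS output with respect to each patch embedding
influence = patches.grad.abs().sum(dim=-1)        # bs, n_patches
print(influence.shape)                            # torch.Size([1, 16])
print(bool((influence > 0).all()))
```

After a single attention layer, every patch already contributes to the CLS output, which is why reading only position 0 is enough for classification.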


<img src="https://machinelearningmastery.com/wp-content/uploads/2022/02/vit_1.png" height=500 width=800 >

In [None]:
class ViT(nn.Module):
    def __init__(
            self,
            img_size,
            patch_size,
            embed_size,
            in_channels,
            num_heads,
            num_layers,
            ff_expansion,
            dout,
            classes,
            ):
        super(ViT,self).__init__()

        self.n_patches = (img_size // patch_size) ** 2 # HW/P^2
        assert img_size % patch_size == 0, f"Image size {img_size} must be divisible by patch size {patch_size}"

        self.pos_embed = nn.Parameter(torch.randn(1, self.n_patches + 1 , embed_size)) # + 1 for CLS token
        self.cls = nn.Parameter(torch.randn(1,1,embed_size))
        self.patch_embed = PatchEmbedding(img_size,patch_size,embed_size,in_channels)
        encoder_layer = nn.TransformerEncoderLayer(
                    d_model=embed_size,
                    nhead=num_heads,
                    dim_feedforward=int(embed_size * ff_expansion),
                    dropout=dout,
                    activation='gelu',  # GELU, as in the ViT paper
                    batch_first=True,
                    norm_first=True  # pre-norm architecture like ViT paper
                )
        self.encoder = nn.TransformerEncoder(encoder_layer,num_layers)
        self.mlp_head = nn.Linear(embed_size,classes)


    def forward(self,img):
        img = self.patch_embed(img)
        cls = self.cls.expand(img.size(0), -1, -1)  # bs, 1, embed_size
        img = torch.cat([cls, img], dim=1)  # prepend the CLS token: bs, n_patches + 1, embed_size

        # pos_embed has shape (1, n_patches + 1, embed_size). You could expand it to the
        # batch size, but PyTorch broadcasts the size-1 batch dimension automatically.
        img = img + self.pos_embed

        # Now we are ready to pass these patches to the Encoder
        img = self.encoder(img)

        # Only pass the Cls token to the output layer
        out = self.mlp_head(img[:,0,:])
        return out


# Let us test it on MNIST

In [None]:
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)) # MNIST mean and std
])

train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('data', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ViT(
            img_size=28,
            patch_size=7,
            embed_size=510,
            in_channels=1,
            num_heads=6,
            num_layers=3,
            ff_expansion=4,
            dout=0.2,
            classes=10,

).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

print(f"Model has {sum(p.numel() for p in model.parameters())} parameters")
print(f"Device: {device}")


Model has 9423280 parameters
Device: cuda
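The printed parameter count can be reproduced by hand. The sketch below redoes the arithmetic for this configuration (embed_size=510, 16 patches of 7x7 from a 28x28 image, 3 layers, ff_expansion=4, 10 classes), assuming a bias in every Linear/Conv2d layer and two LayerNorms per encoder layer, which is what `nn.TransformerEncoderLayer` uses:

```python
embed, n_patches, in_ch, P, layers, ff, classes = 510, (28 // 7) ** 2, 1, 7, 3, 4, 10

pos_embed = (n_patches + 1) * embed                   # 17 positions (16 patches + CLS)
cls_token = embed
patch_proj = embed * in_ch * P * P + embed            # Conv2d kernels + bias
attention = 4 * (embed * embed + embed)               # Q, K, V and output projections
mlp = (embed * ff * embed + ff * embed) + (ff * embed * embed + embed)  # two Linear layers
norms = 2 * 2 * embed                                 # two LayerNorms (weight + bias)
per_layer = attention + mlp + norms
head = embed * classes + classes

total = pos_embed + cls_token + patch_proj + layers * per_layer + head
print(per_layer, total)  # 3127830 9423280
```

Almost all parameters sit in the encoder layers, and about two thirds of each layer is the feed-forward MLP.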




In [None]:
from tqdm import tqdm
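def train(model,criterion,train_loader,test_loader,num_epochs=5):
    # Uses the global `optimizer` and `device` defined in the previous cell
    for epoch in range(num_epochs):
        model.train()  # switch back to training mode each epoch, since model.eval() is called below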
        total_loss = 0
        correct = 0
        total = 0

        for data, target in tqdm(train_loader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)

        accuracy = 100. * correct / total
        print(f'Epoch {epoch+1}: Loss: {total_loss/len(train_loader):.4f}, Accuracy: {accuracy:.2f}%')

        model.eval()
        test_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for data, target in tqdm(test_loader):
                data, target = data.to(device), target.to(device)
                output = model(data)
                test_loss += criterion(output, target).item()
                pred = output.argmax(dim=1)
                correct += pred.eq(target).sum().item()
                total += target.size(0)

        test_accuracy = 100. * correct / total
        print(f'Test Accuracy: {test_accuracy:.2f}%')
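`model.train()` and `model.eval()` only flip a mode flag on every submodule; in these models the layers whose output depends on it are the dropout layers (LayerNorm behaves the same in both modes). A minimal check of what that flag changes, with a toy input:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(10)

drop.train()
y_train = drop(x)   # each entry is zeroed with prob. 0.5, survivors are scaled by 1 / (1 - p) = 2
drop.eval()
y_eval = drop(x)    # identity in eval mode

print(y_train)
print(y_eval)
```

Because the test pass leaves the model in eval mode, the training pass has to switch back to train mode at the start of each epoch; otherwise dropout is silently disabled from the second epoch on.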



In [None]:
train(model,criterion,train_loader,test_loader)

100%|██████████| 469/469 [00:20<00:00, 22.62it/s]
Epoch 1: Loss: 0.4906, Accuracy: 86.80%
100%|██████████| 79/79 [00:02<00:00, 32.94it/s]
Test Accuracy: 93.77%
100%|██████████| 469/469 [00:20<00:00, 22.70it/s]
Epoch 2: Loss: 0.1618, Accuracy: 95.06%
100%|██████████| 79/79 [00:02<00:00, 32.91it/s]
Test Accuracy: 95.28%
100%|██████████| 469/469 [00:20<00:00, 22.74it/s]
Epoch 3: Loss: 0.1370, Accuracy: 95.81%
100%|██████████| 79/79 [00:02<00:00, 32.93it/s]
Test Accuracy: 95.75%
100%|██████████| 469/469 [00:20<00:00, 22.91it/s]
Epoch 4: Loss: 0.1378, Accuracy: 95.79%
100%|██████████| 79/79 [00:02<00:00, 33.18it/s]
Test Accuracy: 96.51%
100%|██████████| 469/469 [00:20<00:00, 22.88it/s]
Epoch 5: Loss: 0.1205, Accuracy: 96.47%
100%|██████████| 79/79 [00:02<00:00, 33.47it/s]
Test Accuracy: 95.65%

Now that everything works, let's replace **nn.TransformerEncoderLayer** with our own [Encoder](https://github.com/Mohamad-Atif1/paper2code/blob/main/Transformers/Transformer.ipynb).

# ViT with a custom Encoder

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert self.head_dim * heads == embed_size, "embed_size must be divisible by heads"

        self.queries = nn.Linear(self.embed_size, self.embed_size)
        self.keys = nn.Linear(self.embed_size, self.embed_size)
        self.values = nn.Linear(self.embed_size, self.embed_size)

        self.fc = nn.Linear(self.embed_size, self.embed_size)

    def _split_heads(self, x):
        # [bs, N, heads * head_dim] -> [bs, N, heads, head_dim] -> [bs, heads, N, head_dim]
        # -> [bs * heads, N, head_dim]: each head is processed independently,
        # as if there were batches of heads.
        bs, n, _ = x.shape
        x = x.reshape(bs, n, self.heads, self.head_dim).transpose(1, 2)
        return x.reshape(bs * self.heads, n, self.head_dim)

    def forward(self, q, k, v, mask=None):
        bs, q_len, _ = q.shape
        # A direct reshape [bs, N, heads * head_dim] -> [bs * heads, N, head_dim] would mix
        # different tokens and heads; the head axis must be moved next to the batch axis first.
        q = self._split_heads(self.queries(q))
        k = self._split_heads(self.keys(k))
        v = self._split_heads(self.values(v))

        energy = torch.bmm(q, k.transpose(1, 2))  # Q*K^T -> [bs * heads, q_N, k_N]
        if mask is not None:
            # mask must broadcast to [bs * heads, q_N, k_N]; positions where mask == 0 are ignored
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Scaled dot-product attention, normalized over the key axis
        attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=2)

        out = torch.bmm(attention, v)  # [bs * heads, q_N, head_dim]
        # Undo the head split: [bs, heads, q_N, head_dim] -> [bs, q_N, heads * head_dim]
        out = out.reshape(bs, self.heads, q_len, self.head_dim).transpose(1, 2)
        out = out.reshape(bs, q_len, self.embed_size)
        out = self.fc(out)
        return out
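Head splitting is the easiest part of a hand-written attention layer to get wrong: the head axis has to be separated out and moved next to the batch axis before the two are merged. A minimal sketch that checks one split/merge pair against PyTorch's `nn.MultiheadAttention` by reusing its projection weights (`split_heads` and `merge_heads` are helper names introduced here for illustration, and the sizes are toy values):

```python
import torch
import torch.nn as nn

def split_heads(x, heads):
    # [bs, N, heads * head_dim] -> [bs, heads, N, head_dim] -> [bs * heads, N, head_dim]
    bs, n, e = x.shape
    return x.reshape(bs, n, heads, e // heads).transpose(1, 2).reshape(bs * heads, n, e // heads)

def merge_heads(x, bs, heads):
    # [bs * heads, N, head_dim] -> [bs, heads, N, head_dim] -> [bs, N, heads * head_dim]
    _, n, d = x.shape
    return x.reshape(bs, heads, n, d).transpose(1, 2).reshape(bs, n, heads * d)

torch.manual_seed(0)
bs, n, e, heads = 2, 5, 12, 3                     # toy sizes
x = torch.randn(bs, n, e)
ref = nn.MultiheadAttention(e, heads, batch_first=True)

with torch.no_grad():
    # Reuse the reference layer's packed Q/K/V projection weights
    wq, wk, wv = ref.in_proj_weight.chunk(3)
    bq, bk, bv = ref.in_proj_bias.chunk(3)
    q = split_heads(x @ wq.T + bq, heads)
    k = split_heads(x @ wk.T + bk, heads)
    v = split_heads(x @ wv.T + bv, heads)
    attention = torch.softmax(q @ k.transpose(1, 2) / (e // heads) ** 0.5, dim=-1)
    out = ref.out_proj(merge_heads(attention @ v, bs, heads))
    expected, _ = ref(x, x, x)

print(torch.allclose(out, expected, atol=1e-5))   # True
```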




In [None]:
class EncoderBlock(nn.Module):
    # Post-norm block (LayerNorm after each skip connection), as in the original Transformer:
    # x = LN(x + Dropout(MHA(x))), then x = LN(x + Dropout(FFN(x)))

    def __init__(self, embed_size, heads, ff_expansion, dout=0.1):
        super(EncoderBlock, self).__init__()
        self.attention = MultiHeadAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, embed_size * ff_expansion),
            nn.GELU(),
            nn.Linear(embed_size * ff_expansion, embed_size))
        self.dout = nn.Dropout(dout)

    def forward(self, q, k, v, mask=None):
        # Sub-layer 1: multi-head attention + skip connection + LayerNorm
        sub_layer_one = self.dout(self.attention(q, k, v, mask))
        sub_layer_one = self.norm1(sub_layer_one + q)  # skip connection

        # Sub-layer 2: feed-forward network + skip connection + LayerNorm
        sub_layer_two = self.dout(self.feed_forward(sub_layer_one))
        sub_layer_two = self.norm2(sub_layer_two + sub_layer_one)

        return sub_layer_two
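The encoder block above is post-norm (LayerNorm after each residual addition), as in the original Transformer. The ViT paper instead uses pre-norm, applying LayerNorm before each sub-layer: z' = z + MSA(LN(z)), then z = z' + MLP(LN(z')). A minimal sketch of such a block, assuming PyTorch's `nn.MultiheadAttention` in place of the notebook's `MultiHeadAttention` for brevity (`PreNormEncoderBlock` is a hypothetical name):

```python
import torch
import torch.nn as nn

class PreNormEncoderBlock(nn.Module):
    """Hypothetical pre-norm block: x = x + MSA(LN(x)); x = x + MLP(LN(x))."""

    def __init__(self, embed_size, heads, ff_expansion, dout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_size)
        # nn.MultiheadAttention used for brevity; a hand-written attention layer could be swapped in
        self.attention = nn.MultiheadAttention(embed_size, heads, dropout=dout, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, embed_size * ff_expansion),
            nn.GELU(),
            nn.Dropout(dout),
            nn.Linear(embed_size * ff_expansion, embed_size),
        )
        self.dout = nn.Dropout(dout)

    def forward(self, x):
        h = self.norm1(x)                                     # normalize before attention
        attn_out, _ = self.attention(h, h, h, need_weights=False)
        x = x + self.dout(attn_out)                           # residual around attention
        x = x + self.dout(self.feed_forward(self.norm2(x)))  # residual around the MLP
        return x

block = PreNormEncoderBlock(embed_size=64, heads=4, ff_expansion=4)
x = torch.randn(2, 17, 64)   # bs, n_patches + 1, embed_size
print(block(x).shape)        # torch.Size([2, 17, 64])
```

With pre-norm blocks, the output of the last block is not normalized, so the paper applies a final LayerNorm to the encoder output before the classification head.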


In [None]:
class ViT(nn.Module):
    def __init__(
            self,
            img_size,
            patch_size,
            embed_size,
            in_channels,
            num_heads,
            num_layers,
            ff_expansion,
            dout,
            classes,
            ):
        super(ViT,self).__init__()

        self.n_patches = (img_size // patch_size) ** 2 # HW/P^2
        assert img_size % patch_size == 0, f"Image size {img_size} must be divisible by patch size {patch_size}"

        self.pos_embed = nn.Parameter(torch.randn(1, self.n_patches + 1 , embed_size)) # + 1 for CLS token
        self.cls = nn.Parameter(torch.randn(1,1,embed_size))
        self.patch_embed = PatchEmbedding(img_size,patch_size,embed_size,in_channels)
        self.layers = nn.ModuleList()

        for _ in range(num_layers):
            self.layers.append(EncoderBlock(embed_size, num_heads, ff_expansion, dout=dout))  # use the dout argument
        self.mlp_head = nn.Linear(embed_size,classes)


    def forward(self,img):
        img = self.patch_embed(img)
        cls = self.cls.expand(img.size(0), -1, -1)  # bs, 1, embed_size
        img = torch.cat([cls, img], dim=1)  # prepend the CLS token: bs, n_patches + 1, embed_size

        # pos_embed has shape (1, n_patches + 1, embed_size). You could expand it to the
        # batch size, but PyTorch broadcasts the size-1 batch dimension automatically.
        img = img + self.pos_embed

        # Now we are ready to pass these patches to the Encoder
        for layer in self.layers:
            img = layer(img,img,img,None)

        # Only pass the Cls token to the output layer
        out = self.mlp_head(img[:,0,:])
        return out


In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ViT(
            img_size=28,
            patch_size=7,
            embed_size=510,
            in_channels=1,
            num_heads=6,
            num_layers=3,
            ff_expansion=4,
            dout=0.2,
            classes=10,

).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

print(f"Model has {sum(p.numel() for p in model.parameters())} parameters")
print(f"Device: {device}")

Model has 9423280 parameters
Device: cuda


In [None]:
train(model,criterion,train_loader,test_loader)

100%|██████████| 469/469 [00:19<00:00, 23.63it/s]
Epoch 1: Loss: 0.5213, Accuracy: 84.37%
100%|██████████| 79/79 [00:02<00:00, 33.86it/s]
Test Accuracy: 92.42%
100%|██████████| 469/469 [00:19<00:00, 23.70it/s]
Epoch 2: Loss: 0.5855, Accuracy: 82.64%
100%|██████████| 79/79 [00:02<00:00, 34.07it/s]
Test Accuracy: 92.06%
100%|██████████| 469/469 [00:19<00:00, 23.86it/s]
Epoch 3: Loss: 0.2794, Accuracy: 91.96%
100%|██████████| 79/79 [00:02<00:00, 34.10it/s]
Test Accuracy: 91.81%
100%|██████████| 469/469 [00:19<00:00, 23.93it/s]
Epoch 4: Loss: 0.2376, Accuracy: 93.33%
100%|██████████| 79/79 [00:02<00:00, 34.36it/s]
Test Accuracy: 94.06%
100%|██████████| 469/469 [00:19<00:00, 23.88it/s]
Epoch 5: Loss: 0.2097, Accuracy: 94.01%
100%|██████████| 79/79 [00:02<00:00, 34.14it/s]
Test Accuracy: 95.03%

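Our custom encoder works: after 5 epochs it reaches 95.03% test accuracy, close to the 95.65% obtained with **nn.TransformerEncoderLayer**, although its training loss was less stable (it rose in epoch 2). ✅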

---


**REF**

[An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929)