# Import required libraries

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import random

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Seed every RNG this notebook imports (Python `random`, NumPy, PyTorch CPU
# and CUDA) so that weight initialisation, DataLoader shuffling and dropout
# are repeatable across runs. Bit-for-bit determinism on GPU would also need
# torch.use_deterministic_algorithms(True), which this cell does not enable.
SEED = 3
random.seed(SEED)
np.random.seed(SEED)  # numpy is imported above but was previously left unseeded
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

Setting the hyperparameters


In [None]:
# --- Optimisation ---
BATCH_SIZE = 128        # samples per mini-batch
EPOCHS = 10             # full passes over the training set
LEARNING_RATE = 3e-4    # optimizer step size

# --- Data / patching ---
PATCH_SIZE = 4          # side length of each square patch, in pixels
NUM_CLASSES = 10        # CIFAR-10 has 10 classes
IMAGE_SIZE = 32         # input images are IMAGE_SIZE x IMAGE_SIZE
CHANNELS = 3            # RGB

# --- Transformer ---
EMBED_DIM = 256         # token width; must be divisible by NUM_HEADS
NUM_HEADS = 8           # attention heads per encoder layer
DEPTH = 6               # number of stacked encoder layers
MLP_DIM = 512           # hidden width of each encoder layer's MLP
DROP_RATE = 0.1         # dropout probability (attention and MLP)

In [None]:
# Convert images to float tensors in [0, 1], then normalise each of the three
# RGB channels with mean 0.5 / std 0.5, which maps pixel values to [-1, 1].
# Zero-centred inputs tend to help optimisation converge faster.
# NOTE: the previous `(0.5)` was a parenthesised float, not a 1-tuple; the
# explicit per-channel tuples below are numerically identical but say what
# is meant for 3-channel CIFAR-10 images.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

In [None]:
# Training split of CIFAR-10, stored under ./data (downloaded if missing)
# and preprocessed with `transform`.
train_dataset = datasets.CIFAR10(
    root="data",
    train=True,
    transform=transform,
    download=True,
)

In [None]:
# Held-out test split of CIFAR-10, same location and preprocessing as the
# training split.
test_dataset = datasets.CIFAR10(
    root="data",
    train=False,
    transform=transform,
    download=True,
)

In [None]:
# Wrap each dataset in a DataLoader that yields mini-batches of BATCH_SIZE
# samples; batching is far more computationally efficient than feeding
# single images. Only the training data is shuffled, so evaluation always
# visits the test set in the same order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Understanding the data. len() of a DataLoader is the number of mini-batches
# per epoch, not the number of samples; the sample count is
# len(loader.dataset), so both are shown.
print(f"Length of train_loader: {len(train_loader)} batches "
      f"({len(train_loader.dataset)} samples)")
print(f"Length of test_loader: {len(test_loader)} batches "
      f"({len(test_loader.dataset)} samples)")

## Building a Vision Transformer (ViT) model from scratch

In [None]:
class PatchEmbedding(nn.Module):
    """Split an image into patches and embed them as a token sequence.

    A Conv2d whose kernel size equals its stride projects each non-overlapping
    ``patch_size`` x ``patch_size`` patch to an ``embed_dim`` vector. A learned
    class token is prepended, and a learned positional embedding is added to
    every token.

    Args:
        img_size: side length of the (square) input image.
        patch_size: side length of each square patch.
        in_channels: number of input image channels.
        embed_dim: width of each output token.

    Input is expected to be (B, in_channels, H, W), with H and W giving the
    same patch count as ``img_size`` (otherwise the positional embedding will
    not broadcast). Output is (B, 1 + num_patches, embed_dim).
    """

    def __init__(self, img_size, patch_size, in_channels, embed_dim):
        super().__init__()
        self.patch_size = patch_size
        # Each output pixel of this conv is one patch embedding.
        self.proj = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )
        patches_per_side = img_size // patch_size
        num_patches = patches_per_side * patches_per_side
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        # One positional vector per patch plus one for the class token.
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))

    def forward(self, x: torch.Tensor):
        batch = x.size(0)
        # (B, E, H/P, W/P) -> (B, E, N) -> (B, N, E)
        tokens = self.proj(x).flatten(2).transpose(1, 2)
        cls = self.cls_token.expand(batch, -1, -1)
        sequence = torch.cat((cls, tokens), dim=1)
        return sequence + self.pos_embed


In [None]:
class MLP(nn.Module):
    """Two-layer feed-forward network used inside each encoder layer.

    Computes Linear(in -> hidden) -> GELU -> Dropout -> Linear(hidden -> in)
    -> Dropout, so the output's last dimension equals ``in_features``.

    Args:
        in_features: input (and output) feature width.
        hidden_features: width of the hidden layer.
        drop_rate: dropout probability applied after each linear layer.
    """

    def __init__(self, in_features, hidden_features, drop_rate):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, in_features)
        # Dropout has no parameters, so one module serves both positions.
        self.dropout = nn.Dropout(drop_rate)

    def forward(self, x):
        hidden = self.dropout(F.gelu(self.fc1(x)))
        return self.dropout(self.fc2(hidden))

In [None]:
class TransformerEncoderLayer(nn.Module):
    """One encoder layer: LayerNorm is applied before each sub-layer
    (self-attention, then MLP), and each sub-layer's output is added back to
    its input as a residual.

    Args:
        embed_dim: token width (must be divisible by ``num_heads``).
        num_heads: number of attention heads.
        mlp_dim: hidden width of the feed-forward MLP.
        drop_rate: dropout probability for attention weights and the MLP.

    Input/output are (B, S, embed_dim), since the attention is built with
    ``batch_first=True``.
    """

    def __init__(self, embed_dim, num_heads, mlp_dim, drop_rate):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=drop_rate, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(in_features=embed_dim, hidden_features=mlp_dim, drop_rate=drop_rate)

    def forward(self, x):
        # Normalise once and reuse the result as query, key and value; the
        # original evaluated norm1(x) three times per forward pass.
        h = self.norm1(x)
        # need_weights=False: the attention-weight matrix was discarded anyway,
        # and skipping it avoids materialising/averaging it every step.
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x

In [None]:
class VisionTransformer(nn.Module):
    """Vision Transformer image classifier.

    Pipeline: PatchEmbedding (class token + positional embeddings) ->
    ``depth`` stacked TransformerEncoderLayer modules -> final LayerNorm ->
    linear head applied to the class-token position.

    Args:
        img_size: side length of the square input image.
        patch_size: side length of each square patch.
        in_channels: number of input image channels.
        num_classes: number of output logits.
        embed_dim: token width.
        depth: number of encoder layers.
        num_heads: attention heads per encoder layer.
        mlp_dim: hidden width of each encoder layer's MLP.
        drop_rate: dropout probability inside the encoder layers.
    """

    def __init__(self, img_size, patch_size, in_channels, num_classes, embed_dim, depth, num_heads, mlp_dim, drop_rate):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        layers = [
            TransformerEncoderLayer(embed_dim, num_heads, mlp_dim, drop_rate)
            for _ in range(depth)
        ]
        self.encoder = nn.Sequential(*layers)
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Return class logits, one row of ``num_classes`` values per image."""
        tokens = self.norm(self.encoder(self.patch_embed(x)))
        # Position 0 holds the class token prepended by PatchEmbedding.
        return self.head(tokens[:, 0])

In [None]:
# Build the Vision Transformer from the hyperparameters above (keyword
# arguments guard against silently swapping same-typed positional values)
# and move its parameters to the selected device.
model = VisionTransformer(
    img_size=IMAGE_SIZE,
    patch_size=PATCH_SIZE,
    in_channels=CHANNELS,
    num_classes=NUM_CLASSES,
    embed_dim=EMBED_DIM,
    depth=DEPTH,
    num_heads=NUM_HEADS,
    mlp_dim=MLP_DIM,
    drop_rate=DROP_RATE,
).to(device)

In [None]:
# Cross-entropy over the raw logits from the classification head, optimised
# with Adam at LEARNING_RATE. (`optim` is the torch.optim module imported above.)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)


In [None]:
#Defining a training loop
def train(model, loader, optimizer, criterion, device=None):
    """Run one training epoch and return ``(mean_loss, accuracy)``.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
        loader: iterable of ``(inputs, targets)`` mini-batches.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        criterion: loss called as ``criterion(logits, targets)``; assumed to
            return the batch-mean loss (the default for ``nn.CrossEntropyLoss``).
        device: where each batch is moved. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        Loss and accuracy averaged over the samples actually seen this epoch.
        Counting seen samples (rather than ``len(loader.dataset)``) stays
        correct with ``drop_last=True`` or with a subset sampler.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    total_loss, correct, seen = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)              # forward pass -> logits
        loss = criterion(out, y)
        loss.backward()             # back-prop
        optimizer.step()            # gradient-descent update

        batch_size = x.size(0)
        # criterion returns the batch mean; re-weight by batch size so the
        # epoch average is per-sample even when the last batch is smaller.
        total_loss += loss.item() * batch_size
        correct += (out.argmax(1) == y).sum().item()
        seen += batch_size

    if seen == 0:
        raise ValueError("train(): loader yielded no samples")
    return total_loss / seen, correct / seen

In [None]:
def evaluate(model, loader, device=None):
    """Return the classification accuracy of ``model`` over ``loader``.

    The model is switched to eval mode (disabling dropout) and left there.
    Inference runs under ``torch.inference_mode()``, so no autograd graph is built.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
        loader: iterable of ``(inputs, targets)`` mini-batches.
        device: where each batch is moved. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        Fraction of correctly classified samples among those actually seen,
        which stays correct with ``drop_last=True`` or subset samplers.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    correct, seen = 0, 0
    with torch.inference_mode():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            correct += (out.argmax(1) == y).sum().item()
            seen += x.size(0)

    if seen == 0:
        raise ValueError("evaluate(): loader yielded no samples")
    return correct / seen

# Training

In [None]:
from tqdm.auto import tqdm

In [None]:
# Train for EPOCHS epochs, evaluating on the test set after each one. Both
# accuracies are recorded as fractions in [0, 1] for the plot below.
train_accuracies, test_accuracies = [], []

for epoch in tqdm(range(EPOCHS)):
    train_loss, train_acc = train(model, train_loader, optimizer, criterion)
    test_acc = evaluate(model, test_loader)
    train_accuracies.append(train_acc)
    test_accuracies.append(test_acc)
    # The accuracies are fractions. The `%` format spec multiplies by 100 and
    # appends "%". The old code appended a literal "%" to the raw fraction
    # (e.g. "0.4512%"), and only for train accuracy.
    print(
        f"Epoch: {epoch + 1}/{EPOCHS}, Train loss: {train_loss:.4f}, "
        f"Train acc: {train_acc:.2%}, Test acc: {test_acc:.2%}"
    )


In [None]:
# Plot per-epoch accuracy. The x-axis uses 1-based epoch numbers so it lines
# up with the "Epoch: k/EPOCHS" lines printed during training; plotting the
# bare lists would start the axis at 0.
epoch_numbers = range(1, len(train_accuracies) + 1)
plt.plot(epoch_numbers, train_accuracies, label="Train Accuracy")
plt.plot(epoch_numbers, test_accuracies, label="Test Accuracy")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.legend()
plt.title("Training and Test Accuracy")
plt.show()