In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import dataloader
from torchvision import datasets, transforms
import torchvision
import matplotlib.pyplot as plt
import numpy as np
import random
from tqdm.auto import tqdm

# Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


In [None]:
random.seed(42)             # Python stdlib RNG
np.random.seed(42)          # legacy global NumPy RNG
torch.manual_seed(42)       # PyTorch default generator(s)
torch.cuda.manual_seed(42)  # current CUDA device; ignored when CUDA is unavailable

Hyperparameters


In [None]:
# Optimisation
Batch_size, Epochs, Learning_rate = 512, 100, 3e-4

# Input geometry: 32x32 images with 3 channels, cut into 4x4 patches (8x8 = 64 patch tokens)
img_size, Channels, Patch_size = 32, 3, 4
Num_classes = 10

# Transformer architecture
Embed_dim, Num_heads, Depth = 384, 8, 6   # token width, attention heads, number of encoder blocks
Mlp_dim = 512                             # hidden width of each block's feed-forward MLP
Drop_rate = 0.1                           # dropout probability used in attention and in the MLP

Image augmentation:

In [None]:
# Per-channel normalisation statistics shared by both pipelines.
# NOTE(review): these are the widely copied CIFAR-10 values; the std triple is
# often reported to differ from the true per-pixel training-set std
# (~(0.247, 0.243, 0.261)) — verify before relying on exact statistics.
mean = (0.4914, 0.4822, 0.4465)
std = (0.2023, 0.1994, 0.2010)

# Training pipeline: random shift (pad 4 px, crop back to 32x32) at native
# resolution, resize to the model input size, random horizontal flip,
# then tensor conversion and normalisation.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.Resize(img_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)

# Evaluation pipeline: deterministic — resize, tensor conversion, normalisation.
transform_test = transforms.Compose(
    [
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)

Dataset: CIFAR-10

In [None]:
# CIFAR-10 train/test splits, downloaded to ./data on first use; each split
# gets its own transform pipeline (augmentation only for training).
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_train,
)
test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)

In [None]:
# Batch loaders for both splits. Only the training loader shuffles.
# pin_memory (enabled only when training on CUDA) allocates batches in
# page-locked host memory, which speeds up the per-batch `.to(device)` copy.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=Batch_size,
    num_workers=8,
    shuffle=True,
    pin_memory=(device == 'cuda'),
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=Batch_size,
    num_workers=8,
    shuffle=False,
    pin_memory=(device == 'cuda'),
)



ViT (Vision Transformer) model

In [None]:
class PatchEmbedding(nn.Module):
    """Turn an image into a token sequence: patch tokens plus a leading [CLS] token.

    A strided Conv2d projects each patch to `embed_dim`. The flattened patch
    grid is prefixed with a learnable [CLS] token, and a learnable positional
    embedding is added to every token.

    Args:
        img_size: side length of the (square) input image, in pixels.
        patch_size: side length of each square patch (the conv kernel size).
        in_channels: number of input image channels.
        embed_dim: dimensionality of every output token.
        stride: step between neighbouring patches. Defaults to `patch_size`,
            which gives non-overlapping patches (the standard ViT setup). A
            smaller value, e.g. `patch_size // 2`, gives overlapping patches.

    Raises:
        ValueError: if `stride` is not positive or `patch_size` > `img_size`.

    The conv must produce exactly `num_patches` tokens for the positional
    embedding addition to broadcast, so inputs are expected to be
    img_size x img_size. Output shape: (B, num_patches + 1, embed_dim).
    """
    def __init__(self,img_size,patch_size,in_channels,embed_dim,stride=None):
        super().__init__()
        if stride is None:
            stride = patch_size
        if stride <= 0:
            raise ValueError(f"stride must be positive, got {stride}")
        if patch_size > img_size:
            raise ValueError(f"patch_size ({patch_size}) must not exceed img_size ({img_size})")
        self.patch_size=patch_size
        self.stride=stride

        # One conv application == one linear projection of one patch.
        self.proj=nn.Conv2d(in_channels,embed_dim,kernel_size=patch_size,stride=stride)
        # Conv output positions per side (no padding): floor((img - patch) / stride) + 1.
        # For stride == patch_size this equals img_size // patch_size.
        grid=(img_size-patch_size)//stride+1
        num_patches=grid**2
        self.num_patches=num_patches

        # NOTE(review): both parameters are drawn from a standard normal (std=1);
        # reference ViT implementations typically use a truncated normal with
        # std~0.02. Kept as-is to preserve this notebook's training behaviour.
        self.cls_token=nn.Parameter(torch.randn(1,1,embed_dim))
        self.pos_embed=nn.Parameter(torch.randn(1,num_patches+1,embed_dim))
    def forward(self,x):
        B=x.size(0)
        x=self.proj(x)                          # (B, embed_dim, grid, grid)
        x=x.flatten(2).transpose(1,2)           # (B, num_patches, embed_dim)
        cls_token=self.cls_token.expand(B,-1,-1)
        x=torch.cat((cls_token,x),dim=1)        # [CLS] is token 0
        x=x+self.pos_embed
        return x

class MLP(nn.Module):
    """Position-wise feed-forward network used inside each encoder block.

    Linear(in -> hidden) -> GELU -> Dropout -> Linear(hidden -> in) -> Dropout.
    The output's last dimension equals `in_features`.
    """
    def __init__(self,in_features,hidden_features,drop_rate):
        super().__init__()
        # Attribute names (fc1, fc2, drop) are the state_dict keys; keep them stable.
        self.fc1 = nn.Linear(in_features, hidden_features)  # expand
        self.fc2 = nn.Linear(hidden_features, in_features)  # project back
        self.drop = nn.Dropout(drop_rate)                   # shared by both stages
    def forward(self,x):
        hidden = F.gelu(self.fc1(x))
        hidden = self.drop(hidden)
        projected = self.fc2(hidden)
        return self.drop(projected)
class TransformerEncoderBlock(nn.Module):
    """Pre-LayerNorm transformer encoder block.

    x = x + SelfAttention(LayerNorm(x))
    x = x + MLP(LayerNorm(x))

    Attention is built with batch_first=True, so `x` is (batch, seq, embed_dim).
    The output has the same shape as the input.
    """
    def __init__(self,embed_dim,num_heads,mlp_dim,drop_rate):
        super().__init__()
        self.norm1=nn.LayerNorm(embed_dim)
        self.attn=nn.MultiheadAttention(embed_dim,num_heads,dropout=drop_rate, batch_first= True)
        self.norm2=nn.LayerNorm(embed_dim)
        self.mlp=MLP(embed_dim,mlp_dim,drop_rate)
    def forward(self,x):
        # Normalise once (the original ran norm1 three times per call). Passing
        # the *same* tensor as query/key/value lets MultiheadAttention use its
        # packed in-projection and self-attention fast path. need_weights=False
        # skips building the head-averaged attention map, which was discarded anyway.
        h=self.norm1(x)
        attn_out,_=self.attn(h,h,h,need_weights=False)
        x=x+attn_out
        x=x+self.mlp(self.norm2(x))
        return x

In [None]:
class VisionTransformer(nn.Module):
    """ViT classifier.

    Pipeline: patch embedding -> `depth` encoder blocks -> LayerNorm -> linear
    head applied to the first token ([CLS]) only. Submodule attribute names
    (patch_embed, encoder, norm, head) are the state_dict keys.
    """

    def __init__(self, img_size, patch_size, in_channels, num_classes, embed_dim, depth, num_heads, mlp_dim, drop_rate):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        blocks = []
        for _ in range(depth):
            blocks.append(TransformerEncoderBlock(embed_dim, num_heads, mlp_dim, drop_rate))
        self.encoder = nn.Sequential(*blocks)
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        tokens = self.patch_embed(x)
        tokens = self.norm(self.encoder(tokens))
        cls_repr = tokens[:, 0]  # token 0 is the [CLS] token prepended by patch_embed
        return self.head(cls_repr)

Model initialisation

In [None]:
# Build the ViT from the hyperparameters above and move it to the chosen device;
# the bare `model` on the last line displays its module tree.
model = VisionTransformer(
    img_size=img_size,
    patch_size=Patch_size,
    in_channels=Channels,
    num_classes=Num_classes,
    embed_dim=Embed_dim,
    depth=Depth,
    num_heads=Num_heads,
    mlp_dim=Mlp_dim,
    drop_rate=Drop_rate,
).to(device)
model

VisionTransformer(
  (patch_embed): PatchEmbedding(
    (proj): Conv2d(3, 384, kernel_size=(4, 4), stride=(4, 4))
  )
  (encoder): Sequential(
    (0): TransformerEncoderBlock(
      (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=384, out_features=384, bias=True)
      )
      (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=384, out_features=512, bias=True)
        (fc2): Linear(in_features=512, out_features=384, bias=True)
        (drop): Dropout(p=0.1, inplace=False)
      )
    )
    (1): TransformerEncoderBlock(
      (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=384, out_features=384, bias=True)
      )
      (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (mlp): MLP(
 

Loss and optimiser

In [None]:
# Cross-entropy with label smoothing 0.1 (targets are softened away from one-hot).
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
# Use the Learning_rate hyperparameter rather than a duplicated literal, so the
# config cell actually controls the learning rate.
optimizer = torch.optim.AdamW(model.parameters(), lr=Learning_rate, weight_decay=1e-4)
# Cosine decay over the full Epochs budget; the training loop steps it once per epoch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=Epochs)

In [None]:
def train(model,dataloader,criterion,optimizer,device):
    """Run one training epoch over `dataloader`, updating `model` in place.

    The model is put in train mode. Each batch is moved to `device`, forwarded,
    scored with `criterion`, back-propagated, and followed by one
    `optimizer.step()`.

    Returns:
        (mean_loss, accuracy) over `len(dataloader.dataset)` samples. Each batch
        loss is re-weighted by its batch size, which assumes `criterion`
        averages over the batch (the nn.CrossEntropyLoss default). Accuracy is
        computed on the batches as yielded by `dataloader`, in train mode.
    """
    model.train()
    loss_sum = 0.0
    hits = 0
    for inputs, targets in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()
        batch_n = inputs.size(0)
        loss_sum += batch_loss.item() * batch_n
        hits += logits.argmax(dim=1).eq(targets).sum().item()
    n_samples = len(dataloader.dataset)
    return loss_sum / n_samples, hits / n_samples
def evaluate(model,dataloader,criterion,device,return_loss=False):
    """Evaluate `model` on every batch of `dataloader` without updating it.

    The model is switched to eval mode and the loop runs under
    torch.inference_mode(), so no gradients are tracked.

    Args:
        model: network to evaluate.
        dataloader: yields (inputs, targets) batches.
        criterion: loss function. It is called only when `return_loss` is True.
        device: device each batch is moved to before the forward pass.
        return_loss: if True, also return the sample-weighted mean loss. This
            assumes `criterion` averages over the batch.

    Returns:
        accuracy (correct argmax predictions / len(dataloader.dataset)) by
        default, or (mean_loss, accuracy) when `return_loss` is True. The tuple
        uses the same order as `train`.
    """
    model.eval()
    total_loss, correct = 0.0, 0
    with torch.inference_mode():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            output = model(x)
            if return_loss:
                total_loss += criterion(output, y).item() * x.size(0)
            correct += (output.argmax(dim=1) == y).sum().item()
    n_samples = len(dataloader.dataset)
    accuracy = correct / n_samples
    if return_loss:
        return total_loss / n_samples, accuracy
    return accuracy

In [None]:
# Train with early stopping: keep the checkpoint with the best test accuracy
# and stop after `patience` consecutive epochs without improvement.
# NOTE(review): checkpoint selection and early stopping use the *test* split, so
# the best test accuracy reported here is optimistically biased; a validation
# split carved from the training set would avoid this.
best_acc = float("-inf")  # guarantees epoch 1 writes a checkpoint from *this* run
patience = 5
counter = 0
train_accuracies, test_accuracies = [], []
for epoch in tqdm(range(Epochs)):
    train_loss, train_acc = train(model, train_dataloader, criterion, optimizer, device)
    test_acc = evaluate(model, test_dataloader, criterion, device)
    train_accuracies.append(train_acc)
    test_accuracies.append(test_acc)
    # tqdm.write prints without corrupting the progress bar (unlike print).
    tqdm.write(f"Epochs: {epoch+1}/{Epochs},Train Loss:{train_loss:.4f}, Train acc: {train_acc:.4f}, Test acc: {test_acc:.4f}")
    scheduler.step()
    if test_acc > best_acc:
        best_acc = test_acc
        counter = 0
        torch.save(model.state_dict(), "best_model.pth")
    else:
        counter += 1
        if counter >= patience:
            tqdm.write("Early stopping triggered.")
            break
# Restore the best weights; map_location keeps tensors on the current device.
model.load_state_dict(torch.load("best_model.pth", map_location=device))


  0%|          | 0/100 [00:00<?, ?it/s]

Epochs: 1/100,Train Loss:1.9572, Train acc: 0.3166, Test acc: 0.4407
Epochs: 2/100,Train Loss:1.6888, Train acc: 0.4533, Test acc: 0.5078
Epochs: 3/100,Train Loss:1.5946, Train acc: 0.4977, Test acc: 0.5404
Epochs: 4/100,Train Loss:1.5215, Train acc: 0.5313, Test acc: 0.5658
Epochs: 5/100,Train Loss:1.4597, Train acc: 0.5605, Test acc: 0.6011
Epochs: 6/100,Train Loss:1.4145, Train acc: 0.5838, Test acc: 0.6103
Epochs: 7/100,Train Loss:1.3747, Train acc: 0.6016, Test acc: 0.6214
Epochs: 8/100,Train Loss:1.3438, Train acc: 0.6164, Test acc: 0.6414
Epochs: 9/100,Train Loss:1.3091, Train acc: 0.6315, Test acc: 0.6564
Epochs: 10/100,Train Loss:1.2870, Train acc: 0.6425, Test acc: 0.6702
Epochs: 11/100,Train Loss:1.2612, Train acc: 0.6520, Test acc: 0.6757
Epochs: 12/100,Train Loss:1.2318, Train acc: 0.6676, Test acc: 0.6904
Epochs: 13/100,Train Loss:1.2154, Train acc: 0.6764, Test acc: 0.6919
Epochs: 14/100,Train Loss:1.1932, Train acc: 0.6868, Test acc: 0.6995
Epochs: 15/100,Train Loss:1.1

<All keys matched successfully>