# Importing libraries

In [None]:
import torch
import torchvision
from torch import nn

In [None]:
from torchvision import datasets
from torch.utils.data import DataLoader

import path
import numpy as np
import os
import matplotlib.pyplot as plt
from PIL import Image

# Setting up device

In [None]:
# Use the CUDA GPU when PyTorch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device  # echo the choice as the cell output

# Loading the pretrained Vision Transformer (ViT-B/16)

In [None]:
# Pretrained ViT-B/16 using torchvision's DEFAULT weights for this architecture.
# The same weights object later supplies the matching preprocessing transforms.
vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
vit_model = torchvision.models.vit_b_16(weights=vit_weights)
vit_model = vit_model.to(device)

# Installing torchinfo (provides `summary` for visualizing the architecture)

In [None]:
# torchinfo provides `summary`, used below to print the model's layer table.
# Install it on the fly only when it is genuinely missing.
try:
    from torchinfo import summary
except ImportError:
    import subprocess
    import sys

    # Install into the interpreter running this kernel. The PyPI package that
    # provides `torchinfo.summary` is `torchinfo`, not `summary`.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "torchinfo"])
    from torchinfo import summary

In [None]:
def summary_vit():
    """Return a torchinfo layer-by-layer summary of the global ``vit_model``.

    The summary is computed for a dummy batch of 32 three-channel 224x224
    images. For every layer it lists the input shape, output shape, parameter
    count and whether the layer is trainable.
    """
    report_options = {
        "col_names": ["input_size", "output_size", "num_params", "trainable"],
        "col_width": 20,
        "row_settings": ["var_names"],
    }
    # (batch, channels, height, width), the layout torchvision models take.
    dummy_batch_shape = (32, 3, 224, 224)
    return summary(model=vit_model, input_size=dummy_batch_shape, **report_options)

# Setting up transforms

In [None]:
# Preprocessing pipeline that ships with the pretrained weights, so images are
# prepared the way the backbone expects. The cell output shows the exact steps.
vit_transform = vit_weights.transforms()
vit_transform  # display the pipeline as the cell output

In [None]:
# Locations of the Kaggle "Simpsons characters" dataset splits.
train_dir, test_dir = (
    "/kaggle/input/the-simpsons-characters-dataset/simpsons_dataset",
    "/kaggle/input/the-simpsons-characters-dataset/kaggle_simpson_testset",
)

In [None]:
import warnings

# Each sub-folder of a root directory becomes one class. Both splits use the
# preprocessing that ships with the pretrained weights.
train_data = datasets.ImageFolder(root=train_dir, transform=vit_transform)
test_data = datasets.ImageFolder(root=test_dir, transform=vit_transform)

# ImageFolder numbers classes separately for each root, from that root's own
# sub-folders. Unless both splits have exactly the same class folders, a test
# label index does not name the same character as a training label index, and
# the test metrics become meaningless.
# NOTE(review): confirm the Kaggle test set is laid out as one sub-folder per
# character; otherwise its labels will not be per-character.
if test_data.class_to_idx != train_data.class_to_idx:
    warnings.warn(
        "train and test class folders differ "
        f"({len(train_data.classes)} vs {len(test_data.classes)} classes); "
        "test labels will not line up with training labels"
    )

In [None]:
# Label names in index order: class_names[i] names training label i.
class_names = train_data.classes
len(class_names)  # how many outputs the new classification head must produce

# Freezing the pretrained layers

In [None]:
# Freeze every pretrained parameter so the backbone is not updated.
for param in vit_model.parameters():
    param.requires_grad = False

# Replace the classification head with a fresh linear layer sized for our
# classes. Its parameters are created after the freeze above, so they are the
# only ones left with requires_grad=True. The input width comes from the model
# itself (`hidden_dim`) instead of a hard-coded 768, so this keeps working
# with other ViT variants.
vit_model.heads = nn.Linear(
    in_features=vit_model.hidden_dim,
    out_features=len(class_names),
).to(device)

summary_vit()

In [None]:
# Print every class name on one line, each followed by two spaces, with no
# trailing newline.
print("".join(f"{name}  " for name in class_names), end="")

In [None]:
# Show one training image with its own label. (Previously the image came from
# test_data while the title came from train_data, so they did not match.)
sample_image, sample_label = train_data[0]

# The dataset tensors are normalized by `vit_transform`. Undo that so imshow
# gets values in [0, 1] and shows the real colours instead of clipping.
channel_mean = torch.tensor(vit_transform.mean).view(-1, 1, 1)
channel_std = torch.tensor(vit_transform.std).view(-1, 1, 1)
display_image = (sample_image * channel_std + channel_mean).clamp(0, 1)

plt.imshow(display_image.permute(1, 2, 0))  # (C, H, W) -> (H, W, C) for matplotlib
plt.axis("off")
plt.title(class_names[sample_label])

# Setting up DataLoader

In [None]:
batch_size = 32
# os.cpu_count() returns None when the count is unknown; DataLoader needs an int.
num_workers = os.cpu_count() or 0
# Pinned host memory only helps host-to-GPU copies; skip it on CPU-only runs.
pin_memory = device == "cuda"

train_dataloader = DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=True,  # reshuffle the training samples every epoch
    num_workers=num_workers,
    pin_memory=pin_memory,
)
test_dataloader = DataLoader(
    dataset=test_data,
    batch_size=batch_size,
    shuffle=False,  # fixed evaluation order
    num_workers=num_workers,
    pin_memory=pin_memory,
)

train_dataloader, test_dataloader

In [None]:
# Sanity-check the pipeline on one training batch. The new head is still
# untrained here, so the prediction is essentially random; this only confirms
# that shapes, devices and label lookup line up.
X, y = next(iter(train_dataloader))

vit_model.eval()
with torch.inference_mode():
    logits = vit_model(X.to(device))  # (batch, len(class_names))

# argmax over the class axis gives one predicted class index per image.
# (The old code took argmax of that again, which yields a batch position,
# not a class.)
pred_labels = logits.argmax(dim=1).cpu()

# Compare the first image's prediction with its true label.
print(
    f"predicted: {class_names[pred_labels[0].item()]} | "
    f"actual: {class_names[y[0].item()]}"
)

In [None]:
class_names[y[0].item()]  # true class of the first image; y holds label indices, so y.argmax() would be a batch position

In [None]:
loss_fn = torch.nn.CrossEntropyLoss()

# Give the optimizer only the parameters that receive gradients (the new head).
# The frozen backbone would otherwise just be carried along unused.
trainable_params = [p for p in vit_model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(
    params=trainable_params,
    lr=3e-3,
    betas=(0.9, 0.999),
    # NOTE(review): torch.optim.Adam applies weight_decay as an L2 term added
    # to the gradient. 0.3 is strong for fine-tuning a single linear head;
    # consider AdamW or a smaller value.
    weight_decay=0.3,
)

# Defining the train step, the test step, and a training loop that combines them

In [None]:
def train_step(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    loss_fn: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
) -> tuple[float, float]:
    """Run one training epoch over ``dataloader``.

    Args:
        model: Network to train; it is switched to training mode.
        dataloader: Yields ``(inputs, targets)`` batches with targets as class
            indices.
        loss_fn: Criterion called as ``loss_fn(logits, targets)``. Assumed to
            return the mean loss over the batch (``CrossEntropyLoss`` does by
            default).
        optimizer: Updates the model's parameters after every batch.
        device: Device each batch is moved to.

    Returns:
        ``(mean_loss, accuracy)`` over every sample in the epoch. Both are
        weighted by batch size, so a smaller final batch does not skew them.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.train()

    loss_sum = 0.0
    correct = 0
    seen = 0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        y_pred = model(X)
        loss = loss_fn(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_len = y.size(0)
        # `loss` is a per-batch mean; scale it back up so the epoch mean is per-sample.
        loss_sum += loss.item() * batch_len
        # argmax over logits equals argmax over softmax probabilities.
        correct += (y_pred.argmax(dim=1) == y).sum().item()
        seen += batch_len

    if seen == 0:
        raise ValueError("dataloader yielded no samples")

    return loss_sum / seen, correct / seen

In [None]:
def test_step(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    loss_fn: torch.nn.Module,
    device: torch.device = torch.device("cpu"),
) -> tuple[float, float]:
    """Evaluate ``model`` on ``dataloader`` without updating it.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        dataloader: Yields ``(inputs, targets)`` batches with targets as class
            indices.
        loss_fn: Criterion called as ``loss_fn(logits, targets)``. Assumed to
            return the mean loss over the batch (``CrossEntropyLoss`` does by
            default).
        device: Device each batch is moved to. Defaults to the CPU.

    Returns:
        ``(mean_loss, accuracy)`` over every sample, weighted by batch size so
        a smaller final batch does not skew them.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()

    loss_sum = 0.0
    correct = 0
    seen = 0

    # No autograd bookkeeping is needed for evaluation.
    with torch.inference_mode():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            y_pred = model(X)
            loss = loss_fn(y_pred, y)

            batch_len = y.size(0)
            # `loss` is a per-batch mean; scale it back up for a per-sample mean.
            loss_sum += loss.item() * batch_len
            # argmax over logits equals argmax over softmax probabilities.
            correct += (y_pred.argmax(dim=1) == y).sum().item()
            seen += batch_len

    if seen == 0:
        raise ValueError("dataloader yielded no samples")

    return loss_sum / seen, correct / seen


# The name matches pytest's `test_*` discovery pattern; stop pytest from
# collecting this helper as a test if the code is ever run under pytest.
test_step.__test__ = False

In [None]:
def train(
    model: torch.nn.Module,
    train_dataloader: torch.utils.data.DataLoader,
    test_dataloader: torch.utils.data.DataLoader,
    loss_fn: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epochs: int,
) -> dict[str, list]:
    """Train ``model`` for ``epochs`` epochs, evaluating after each one.

    Each epoch runs ``train_step`` on ``train_dataloader``, then ``test_step``
    on ``test_dataloader``. It prints a one-line summary and records the four
    metrics.

    Returns:
        Dict with keys ``"train_loss"``, ``"train_acc"``, ``"test_loss"`` and
        ``"test_acc"``, each a list with one value per epoch.
    """
    from tqdm import tqdm

    history = {key: [] for key in ("train_loss", "train_acc", "test_loss", "test_acc")}

    for epoch_number in tqdm(range(1, epochs + 1)):
        train_loss, train_acc = train_step(
            model=model,
            dataloader=train_dataloader,
            loss_fn=loss_fn,
            optimizer=optimizer,
            device=device,
        )
        test_loss, test_acc = test_step(
            model=model,
            dataloader=test_dataloader,
            loss_fn=loss_fn,
            device=device,
        )

        print(
            f"Epoch: {epoch_number} | "
            f"train_loss: {train_loss:.4f} | "
            f"train_acc: {train_acc:.4f} | "
            f"test_loss: {test_loss:.4f} | "
            f"test_acc: {test_acc:.4f}"
        )

        # history keys are in the same order as this tuple of metrics.
        for key, value in zip(history, (train_loss, train_acc, test_loss, test_acc)):
            history[key].append(value)

    return history

# Training and evaluating

In [None]:
# Fine-tune the new head for 5 epochs; `results` keeps the per-epoch metrics.
results = train(
    model=vit_model,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    loss_fn=loss_fn,
    optimizer=optimizer,
    device=device,
    epochs=5,
)