#### About
Image classification using ViT in PyTorch
Dataset link - https://www.kaggle.com/datasets/gpiosenka/100-bird-species

In [None]:
# Colab-only step: mount the user's Google Drive at /content/drive so files
# stored there can be accessed through the local filesystem.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
# Work from the Drive dataset folder so the relative archive path used by the
# unzip command in this cell resolves there.
os.chdir('/content/drive/MyDrive/Datasets/')
!unzip archive.zip

In [None]:
#importing modules
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision import transforms
!pip install torchinfo --quiet
from torchinfo import summary
from PIL import Image, ImageEnhance
import numpy as np
from torchvision.datasets import ImageFolder
import matplotlib.pyplot as plt
%matplotlib inline


In [None]:
#dataset path
# Training and validation image roots on the mounted Drive; both are handed to
# ImageFolder later in the notebook.
train_dir = "/content/drive/MyDrive/Datasets/train/"
val_dir= "/content/drive/MyDrive/Datasets/valid/"

In [None]:
#image enhancement function while training
def _enhance_color(image, f):
    """Scale colour saturation of a PIL image by factor ``f``."""
    return ImageEnhance.Color(image).enhance(f)


def _enhance_contrast(image, f):
    """Scale contrast of a PIL image by factor ``f``."""
    return ImageEnhance.Contrast(image).enhance(f)


def _enhance_brightness(image, f):
    """Scale brightness of a PIL image by factor ``f``."""
    return ImageEnhance.Brightness(image).enhance(f)


def _enhance_sharpness(image, f):
    """Scale sharpness of a PIL image by factor ``f``."""
    return ImageEnhance.Sharpness(image).enhance(f)


def _normal_around_one(stddev):
    """Return a sampler drawing from a normal distribution centred on 1.0."""
    return lambda: np.random.normal(1.0, stddev)


# Index -> adjustment; the same index selects the matching sampler in `factors`.
enhancers = {
    0: _enhance_color,
    1: _enhance_contrast,
    2: _enhance_brightness,
    3: _enhance_sharpness,
}

# Colour and sharpness vary more (std 0.3) than contrast and brightness (std 0.1).
factors = {
    0: _normal_around_one(0.3),
    1: _normal_around_one(0.1),
    2: _normal_around_one(0.1),
    3: _normal_around_one(0.3),
}


def enhance(image):
    """Apply all four adjustments to ``image`` in a random order.

    Each adjustment uses a freshly sampled factor, and the output of one
    adjustment is the input of the next. Returns the adjusted image.
    """
    sequence = [0, 1, 2, 3]
    np.random.shuffle(sequence)
    for key in sequence:
        image = enhancers[key](image, factors[key]())
    return image

In [None]:
# Both pipelines resize to 224x224 with Lanczos resampling and end with
# ToTensor; only the training pipeline adds the random horizontal flip and the
# enhance() photometric jitter.
train_transform = transforms.Compose(
    [
        transforms.Resize((224, 224), Image.LANCZOS),
        transforms.RandomHorizontalFlip(),
        transforms.Lambda(enhance),
        transforms.ToTensor(),
    ]
)

val_transform = transforms.Compose(
    [
        transforms.Resize((224, 224), Image.LANCZOS),
        transforms.ToTensor(),
    ]
)

# Datasets read from the train / validation folders with their own pipeline.
train_data = ImageFolder(root=train_dir, transform=train_transform)
val_data = ImageFolder(root=val_dir, transform=val_transform)

In [None]:
# Peek at one transformed sample to confirm the dataset loads.
train_data[5]

In [None]:
#creating dataloader
batch_size = 64
# Training batches are reshuffled each epoch. Validation is read in a fixed
# order: shuffling cannot change the averaged metrics and only makes the
# evaluation pass harder to reproduce.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

In [None]:
# Sanity check: fetch a single batch, show its tensor shapes, then stop.
for inputs, labels in train_loader:
    print(inputs.shape,labels.shape)
    break

In [None]:
# Prefer the GPU when CUDA is available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Number of target classes (an int), taken from the dataset's class list.
num_classes = len(train_data.classes)
print(num_classes)

In [None]:
!pip install timm
# Fetch the DeiT-tiny (patch 16, 224px) backbone with pretrained weights from torch.hub.
model = torch.hub.load('facebookresearch/deit:main', 'deit_tiny_patch16_224', pretrained=True)
#freezing model
for param in model.parameters():
    param.requires_grad = False

# Replace the classification head. The new layers are created after the freeze
# above, so they are the only parameters that still require gradients.
n_inputs = model.head.in_features
model.head = nn.Sequential(
    nn.Linear(n_inputs, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    # num_classes is already an int; calling len() on it raises TypeError.
    nn.Linear(512, num_classes)
)
model = model.to(device)
print(model.head)

In [None]:
# Print a torchinfo summary for an input of size (1, 3, 224, 224).
summary(model,input_size=(1,3,224,224))

In [None]:
# Cross-entropy loss, optimised with Adam at a learning rate of 1e-4.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)


In [None]:
def _run_epoch(model, loader, loss_criterion, optimizer=None):
    """Run one pass over ``loader``: train if ``optimizer`` is given, else evaluate.

    Batches are moved to the module-level ``device``.

    Returns:
        ``(mean_loss, accuracy)`` averaged over the samples actually seen, with
        accuracy as a fraction in [0, 1]. ``(0.0, 0.0)`` for an empty loader.
    """
    training = optimizer is not None
    # train(True) enables dropout etc.; train(False) is equivalent to eval().
    model.train(training)

    total_loss, total_correct, total_seen = 0.0, 0, 0
    # Gradients are only tracked while training.
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            if training:
                # Clean existing gradients
                optimizer.zero_grad()

            outputs = model(inputs)
            loss = loss_criterion(outputs, labels)

            if training:
                loss.backward()
                optimizer.step()

            n = inputs.size(0)
            # The criterion is assumed to return a per-batch mean, so weight it
            # by the batch size before averaging over the whole pass.
            total_loss += loss.item() * n
            predictions = outputs.argmax(dim=1)
            total_correct += predictions.eq(labels.view_as(predictions)).sum().item()
            total_seen += n

    if total_seen == 0:
        return 0.0, 0.0
    return total_loss / total_seen, total_correct / total_seen


def fit(model, loss_criterion, optimizer, epochs=25):
    """Train ``model`` and evaluate it on the validation set after every epoch.

    Batches come from the module-level ``train_loader`` and ``val_loader``.

    Args:
        model: network to optimise; it is updated in place.
        loss_criterion: callable ``(outputs, labels) -> scalar loss tensor``.
        optimizer: optimizer already bound to the parameters to update.
        epochs: number of full passes over ``train_loader``.

    Returns:
        ``(model, history)`` where ``history`` has one
        ``[train_loss, valid_loss, train_acc, valid_acc]`` row per epoch.
        Losses are per-sample averages; accuracies are fractions in [0, 1].
    """
    import time  # local import: the notebook's import cell does not provide it

    history = []

    for epoch in range(epochs):
        epoch_start = time.perf_counter()
        print("Epoch: {}/{}".format(epoch+1, epochs))

        avg_train_loss, avg_train_acc = _run_epoch(model, train_loader, loss_criterion, optimizer)
        avg_valid_loss, avg_valid_acc = _run_epoch(model, val_loader, loss_criterion)

        history.append([avg_train_loss, avg_valid_loss, avg_train_acc, avg_valid_acc])

        # The message has six placeholders; the elapsed time fills the last one.
        elapsed = time.perf_counter() - epoch_start
        print("Epoch : {:03d}, Training: Loss: {:.4f}, Accuracy: {:.4f}%, \n\t\tValidation : Loss : {:.4f}, Accuracy: {:.4f}%, Time: {:.4f}s".format(epoch+1, avg_train_loss, avg_train_acc*100, avg_valid_loss, avg_valid_acc*100, elapsed))

    return model, history

In [None]:
# Train for 10 epochs and keep the per-epoch metrics for plotting.
model, history = fit(model, criterion, optimizer, epochs=10)

In [None]:
# Save only the model's state_dict to ViT.pth (relative to the current
# working directory).
torch.save(model.state_dict(), 'ViT.pth')


In [None]:
# history rows are [train_loss, valid_loss, train_acc, valid_acc]; convert to
# an array once so columns can be sliced.
history = np.array(history)
plt.plot(history[:,0:2])
plt.legend(['Training Loss', 'Val Loss'])
plt.xlabel('Epoch')
plt.ylabel('Loss')
# Cross-entropy is not bounded by 1, so pin only the lower edge and let
# matplotlib choose the upper limit; a fixed (0, 1) window hides larger losses.
plt.ylim(bottom=0)
plt.show()

In [None]:
# Accuracy curves: columns 2 and 3 of history hold training and validation
# accuracy as fractions, hence the fixed 0..1 axis.
plt.plot(history[:, 2], label='Training Accuracy')
plt.plot(history[:, 3], label='Val Accuracy')
plt.legend()
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim(0, 1)
plt.show()

In [None]:
# Invert the dataset's name -> index mapping so predicted indices can be
# reported as class names.
idx_to_class = {index: name for name, index in train_data.class_to_idx.items()}
print(idx_to_class)


def predict(model, test_image_name):
    """Show an image and print the model's most likely classes for it.

    Up to three predictions are printed, each with its softmax probability.
    Class names are looked up in the module-level ``idx_to_class``.

    Args:
        model: classifier returning one row of raw class scores (logits) per
            input image.
        test_image_name: path or file object accepted by ``PIL.Image.open``.
    """
    test_image = Image.open(test_image_name).convert('RGB')
    print(np.shape(test_image))
    plt.imshow(test_image)

    # Same preprocessing as val_transform (Lanczos resize, then ToTensor) so
    # inference sees inputs prepared the way validation inputs were.
    transform = transforms.Compose([
        transforms.Resize((224, 224), Image.LANCZOS),
        transforms.ToTensor(),
    ])

    # Add the batch dimension and put the input on whatever device holds the
    # model's weights, instead of assuming CUDA whenever it is available.
    target_device = next(model.parameters()).device
    test_image_tensor = transform(test_image).float().unsqueeze(0).to(target_device)

    model.eval()
    with torch.no_grad():
        out = model(test_image_tensor)
        # The classifier built in this notebook ends in nn.Linear and is
        # trained with CrossEntropyLoss, so `out` holds logits; softmax turns
        # them into probabilities (exp alone does not).
        prob = torch.softmax(out, dim=1)
        top_k = min(3, prob.size(1))
        prob_, class_ = prob.topk(top_k, dim=1)

    scores = prob_.cpu().numpy()[0]
    class_ = class_.cpu().numpy()[0]
    for rank, (class_index, score) in enumerate(zip(class_, scores), start=1):
        print("Predcition", rank, ":", idx_to_class[class_index], ", Score: ", score)