#### About
Image classification using ViT in PyTorch
Dataset link - https://www.kaggle.com/datasets/gpiosenka/100-bird-species

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import os
os.chdir('/content/drive/MyDrive/Datasets/')
!unzip archive.zip

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: train/WHITE CHEEKED TURACO/080.jpg  
  inflating: train/WHITE CHEEKED TURACO/081.jpg  
  inflating: train/WHITE CHEEKED TURACO/082.jpg  
  inflating: train/WHITE CHEEKED TURACO/083.jpg  
  inflating: train/WHITE CHEEKED TURACO/084.jpg  
  inflating: train/WHITE CHEEKED TURACO/085.jpg  
  inflating: train/WHITE CHEEKED TURACO/086.jpg  
  inflating: train/WHITE CHEEKED TURACO/087.jpg  
  inflating: train/WHITE CHEEKED TURACO/088.jpg  
  inflating: train/WHITE CHEEKED TURACO/089.jpg  
  inflating: train/WHITE CHEEKED TURACO/090.jpg  
  inflating: train/WHITE CHEEKED TURACO/091.jpg  
  inflating: train/WHITE CHEEKED TURACO/092.jpg  
  inflating: train/WHITE CHEEKED TURACO/093.jpg  
  inflating: train/WHITE CHEEKED TURACO/094.jpg  
  inflating: train/WHITE CHEEKED TURACO/095.jpg  
  inflating: train/WHITE CHEEKED TURACO/096.jpg  
  inflating: train/WHITE CHEEKED TURACO/097.jpg  
  inflating: train/WHITE CHEEKED TU

In [1]:
#importing modules
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision import transforms
!pip install torchinfo --quiet
from torchinfo import summary
from PIL import Image, ImageEnhance
import numpy as np
from torchvision.datasets import ImageFolder
import matplotlib.pyplot as plt
%matplotlib inline


In [2]:
#dataset path
train_dir = "/content/drive/MyDrive/Datasets/train/"
val_dir= "/content/drive/MyDrive/Datasets/valid/"

In [3]:
#image enhancement function while training
enhancers = {
    0: lambda image, f: ImageEnhance.Color(image).enhance(f),
    1: lambda image, f: ImageEnhance.Contrast(image).enhance(f),
    2: lambda image, f: ImageEnhance.Brightness(image).enhance(f),
    3: lambda image, f: ImageEnhance.Sharpness(image).enhance(f)
}

factors = {
        0: lambda: np.random.normal(1.0, 0.3),
        1: lambda: np.random.normal(1.0, 0.1),
        2: lambda: np.random.normal(1.0, 0.1),
        3: lambda: np.random.normal(1.0, 0.3),
    }
    

def enhance(image):
    order = [0, 1, 2, 3]
    np.random.shuffle(order)
    for i in order:
        f = factors[i]()
        image = enhancers[i](image, f)
    return image

In [4]:
train_transform = transforms.Compose([
    transforms.Resize((224,224),Image.LANCZOS),
    
    transforms.RandomHorizontalFlip(),
    transforms.Lambda(enhance),
    transforms.ToTensor(),
])

val_transform = transforms.Compose([
    transforms.Resize((224,224),Image.LANCZOS),
    transforms.ToTensor()
])

train_data = ImageFolder(train_dir, transform=train_transform)
val_data = ImageFolder(val_dir,transform=val_transform)



In [5]:
train_data.__getitem__(5)

(tensor([[[0.7176, 0.7255, 0.7294,  ..., 0.5804, 0.5569, 0.5373],
          [0.7137, 0.7176, 0.7255,  ..., 0.5882, 0.5647, 0.5608],
          [0.7059, 0.7137, 0.7176,  ..., 0.5961, 0.5882, 0.5725],
          ...,
          [0.3804, 0.3765, 0.3686,  ..., 0.3725, 0.3961, 0.3922],
          [0.4314, 0.4314, 0.3843,  ..., 0.3647, 0.3804, 0.3843],
          [0.4000, 0.3961, 0.3843,  ..., 0.3804, 0.3843, 0.3686]],
 
         [[0.7294, 0.7333, 0.7412,  ..., 0.6353, 0.6196, 0.6000],
          [0.7255, 0.7294, 0.7333,  ..., 0.6353, 0.6235, 0.6118],
          [0.7176, 0.7255, 0.7294,  ..., 0.6549, 0.6392, 0.6392],
          ...,
          [0.3843, 0.3804, 0.3686,  ..., 0.4392, 0.4667, 0.4588],
          [0.4314, 0.4314, 0.3843,  ..., 0.4353, 0.4431, 0.4471],
          [0.4039, 0.4000, 0.3843,  ..., 0.4471, 0.4510, 0.4392]],
 
         [[0.6235, 0.6314, 0.6353,  ..., 0.3961, 0.3804, 0.3490],
          [0.6196, 0.6235, 0.6314,  ..., 0.4000, 0.3804, 0.3569],
          [0.6118, 0.6196, 0.6235,  ...,

In [6]:
#creating dataloader
batch_size =64
train_loader = DataLoader(train_data,batch_size,shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_data, batch_size,shuffle=True,  num_workers=4, pin_memory=True)



In [7]:
for i, (inputs, labels) in enumerate(train_loader):
    print(inputs.shape,labels.shape)
    break

torch.Size([64, 3, 224, 224]) torch.Size([64])


In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [9]:
num_classes = len(train_data.classes)
print(num_classes)

500


In [10]:
#Creating a class for extracting embeddings via Patches of images
class ImagePatcher(nn.Module):
    def __init__(self,input_channels=3, patch_size=16, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        # creating the embedding layers
        self.img_cropper_layer = nn.Conv2d(in_channels=input_channels, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size, padding=0)
        self.linear_layer = nn.Flatten(start_dim=2, end_dim=3)

    def forward(self,x):
        image_res = x.shape[-1]
        assert image_res% self.patch_size==0 #divisible  by 16

        cropped_img_features = self.img_cropper_layer(x)
        flattened_img_features = self.linear_layer(cropped_img_features)
        flattened_img_features = flattened_img_features.permute(0,2,1)# as per pytorch

        return flattened_img_features

In [11]:
#defining transformer encoder layer and stacking 12 encoder layers
encoder_layer = nn.TransformerEncoderLayer(d_model=768,nhead=12, dim_feedforward=2048, dropout=0.1, activation='gelu', batch_first=True, norm_first=True)
encoder = nn.TransformerEncoder(encoder_layer=encoder_layer,num_layers=12)


In [12]:
#defining ViT class
class ViT(nn.Module):
    def __init__(self, img_size=224, num_channel=3, patch_size=16, embed_dim=768, p=0.1, num_layers=12,num_heads=12, hidden_dim=2048,num_classes=500):
        super().__init__()
        assert img_size%patch_size==0
        #embedding
        self.cropping_layer = ImagePatcher(input_channels=num_channel,patch_size=patch_size, embed_dim=embed_dim)

        # class token - *
        self.classtoken = nn.Parameter(torch.randn(1,1,embed_dim), requires_grad=True)

        #positional embedding for each
        num_patches = (img_size*img_size) //patch_size**2

        self.positional_embedding = nn.Parameter(torch.randn(1,num_patches+1, embed_dim))

        #dropout
        self.dropout = nn.Dropout(p=0.1) 

        #encoder 
        self.encoders = nn.TransformerEncoder(encoder_layer= nn.TransformerEncoderLayer(d_model=768,nhead=12, dim_feedforward=2048, dropout=0.1, activation='gelu', batch_first=True, norm_first=True), num_layers=num_layers)

        #mlp
        self.mlp = nn.Sequential(nn.LayerNorm(normalized_shape=embed_dim),
                                 nn.Linear(in_features=embed_dim,out_features=num_classes))
        
    
    def forward(self,x):
        batch_size = x.shape[0]
        x = self.cropping_layer(x)
        cls_token = self.classtoken.expand(batch_size,-1,-1)
        x = torch.cat((cls_token,x), dim=1)
        x = self.positional_embedding + x #similar to NLP [CLS] [0][f1] [1][f2]
        x = self.dropout(x)
        x = self.encoders(x)
        x = self.mlp(x[:,0])

        return x


In [13]:
model = ViT(num_classes=len(train_data.classes)).to(device)
model

ViT(
  (cropping_layer): ImagePatcher(
    (img_cropper_layer): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (linear_layer): Flatten(start_dim=2, end_dim=3)
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (encoders): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (linear1): Linear(in_features=768, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=768, bias=True)
        (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
      (1): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out

In [14]:
summary(model,input_size=(1,3,224,224))

Layer (type:depth-idx)                        Output Shape              Param #
ViT                                           [1, 500]                  152,064
├─ImagePatcher: 1-1                           [1, 196, 768]             --
│    └─Conv2d: 2-1                            [1, 768, 14, 14]          590,592
│    └─Flatten: 2-2                           [1, 768, 196]             --
├─Dropout: 1-2                                [1, 197, 768]             --
├─TransformerEncoder: 1-3                     [1, 197, 768]             --
│    └─ModuleList: 2-3                        --                        --
│    │    └─TransformerEncoderLayer: 3-1      [1, 197, 768]             5,513,984
│    │    └─TransformerEncoderLayer: 3-2      [1, 197, 768]             5,513,984
│    │    └─TransformerEncoderLayer: 3-3      [1, 197, 768]             5,513,984
│    │    └─TransformerEncoderLayer: 3-4      [1, 197, 768]             5,513,984
│    │    └─TransformerEncoderLayer: 3-5      [1, 197, 76

In [15]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()


In [16]:
def fit(model, loss_criterion, optimizer, epochs=25):

    history = []
    best_acc = 0.0

    for epoch in range(epochs):
        print("Epoch: {}/{}".format(epoch+1, epochs))
        
        # Set to training mode
        model.train()
        
        # Loss and Accuracy within the epoch
        train_loss = 0.0
        train_acc = 0.0
        
        valid_loss = 0.0
        valid_acc = 0.0
        
        for i, (inputs, labels) in enumerate(train_loader):

            inputs = inputs.to(device)
            labels = labels.to(device)
            
            # Clean existing gradients
            optimizer.zero_grad()
            
            # Forward pass - compute outputs on input data using the model
            outputs = model(inputs)
            
            # Compute loss
            loss = loss_criterion(outputs, labels)
            
            # Backpropagate the gradients
            loss.backward()
            
            # Update the parameters
            optimizer.step()
            
            # Compute the total loss for the batch and add it to train_loss
            train_loss += loss.item() * inputs.size(0)
            
            # Compute the accuracy
            ret, predictions = torch.max(outputs.data, 1)
            correct_counts = predictions.eq(labels.data.view_as(predictions))
            
            # Convert correct_counts to float and then compute the mean
            acc = torch.mean(correct_counts.type(torch.FloatTensor))
            
            # Compute total accuracy in the whole batch and add to train_acc
            train_acc += acc.item() * inputs.size(0)
            

            
        # Validation - No gradient tracking needed
        with torch.no_grad():

            # Set to evaluation mode
            model.eval()

            # Validation loop
            for j, (inputs, labels) in enumerate(val_loader):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Forward pass - compute outputs on input data using the model
                outputs = model(inputs)

                # Compute loss
                loss = loss_criterion(outputs, labels)

                # Compute the total loss for the batch and add it to valid_loss
                valid_loss += loss.item() * inputs.size(0)

                # Calculate validation accuracy
                ret, predictions = torch.max(outputs.data, 1)
                correct_counts = predictions.eq(labels.data.view_as(predictions))

                # Convert correct_counts to float and then compute the mean
                acc = torch.mean(correct_counts.type(torch.FloatTensor))

                # Compute total accuracy in the whole batch and add to valid_acc
                valid_acc += acc.item() * inputs.size(0)

            
        # Find average training loss and training accuracy
        avg_train_loss = train_loss/len(train_data) 
        avg_train_acc = train_acc/len(train_data) 

        # Find average training loss and training accuracy
        avg_valid_loss = valid_loss/len(val_data)  
        avg_valid_acc = valid_acc/len(val_data) 

        history.append([avg_train_loss, avg_valid_loss, avg_train_acc, avg_valid_acc])
                
    
        print("Epoch : {:03d}, Training: Loss: {:.4f}, Accuracy: {:.4f}%, \n\t\tValidation : Loss : {:.4f}, Accuracy: {:.4f}%, Time: {:.4f}s".format(epoch+1, avg_train_loss, avg_train_acc*100, avg_valid_loss, avg_valid_acc*100, epoch_end-epoch_start))
        
            
    return model, history

In [None]:
model,history = fit(model, criterion,optimizer,10)

Epoch: 1/10


In [None]:
torch.save(model.state_dict(), 'ViT.pth')


In [None]:
history = np.array(history)
plt.plot(history[:,0:2])
plt.legend(['Training Loss', 'Val Loss'])
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.ylim(0,1)
plt.show()

In [None]:
plt.plot(history[:,2:4])
plt.legend(['Training Accuracy', 'Val Accuracy'])
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim(0,1)
plt.show()

In [None]:
idx_to_class = {v: k for k, v in train_data.class_to_idx.items()}
print(idx_to_class)


def predict(model, test_image_name):
    
    test_image = Image.open(test_image_name).convert('RGB')
    print(np.shape(test_image))
    plt.imshow(test_image)
 
    transform = transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor() 
          ])
    img_normalized = transform(test_image).float()

    if torch.cuda.is_available():
        test_image_tensor = img_normalized.view(1, 3, 224, 224).cuda()
    else:
        test_image_tensor = img_normalized.view(1, 3, 224, 224)
    
    with torch.no_grad():
        model.eval()
        # Model outputs log probabilities
        out = model(test_image_tensor)
        prob = torch.exp(out)
        prob_, class_ = prob.topk(3, dim=1)
        class_ = class_.cpu().numpy()
        for i in range(3):
            print("Predcition", i+1, ":", idx_to_class[class_[0][i]], ", Score: ", prob_.cpu().numpy()[0][i])