In [1]:
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
!pip install einops
import einops
import torchvision
import torchvision.transforms as transforms
import cv2
import numpy as np
from torchvision import datasets
from torch.utils.data.sampler import SubsetRandomSampler
import os
from pathlib import Path
import torch
from PIL import Image
from tqdm import tqdm
from torchmetrics import F1Score
import matplotlib.pyplot as plt
import copy

In [None]:
# Hyperparameters and model configuration shared by the cells below.
# NOTE(review): this cell must be executed before the model-instantiation cell, which reads these names.
valid_per = 0.20  # NOTE(review): not referenced in the visible cells; data_loaders() is called with literal fractions
batch_size = 8  # passed to Transformer(...) and to test(...)
epochs = 5  # NOTE(review): not referenced in the visible cells; the training calls pass literal epoch counts
device = "cuda"  # device the transformer and the F1 metric are moved to
learning_rate = 0.002  # Adam learning rate
beta1 = 0.9  # first Adam beta
beta2 = 0.990  # second Adam beta
weight_decay = 0.3  # Adam weight_decay
embed_size =768  # width of every token embedding inside the Transformer
num_heads = 8  # forwarded to each EncoderBlock
patch_size = 12  # side length, in pixels, of the square patches cut by Project
in_channels = 3  # channel count of the input images
num_encoders = 8  # number of EncoderBlock layers stacked in the Transformer
num_class = 12  # number of output classes; also passed to F1Score

In [2]:
class Project(nn.Module):
    """Cut an image into square patches, embed them, and append two learnable tokens.

    The forward pass maps ``(batch, in_channels, H, W)`` to
    ``(batch, num_patches + 2, embed_size)``. The class token and the distill
    token are appended, in that order, after the patch embeddings, and a learned
    position embedding is added to every token.

    Args:
        patch_size: side length in pixels of each square patch.
        in_channels: number of image channels (1 for grayscale, 3 for colour).
        embed_size: size of each linearly projected patch.
        batch_size: retained for interface compatibility; the tokens are now
            shared across the batch, so any batch size works.

    Note:
        ``position_embed`` has ``patch_size**2 + 2`` rows, so the input must
        produce exactly ``patch_size**2`` patches (e.g. 144x144 images with
        ``patch_size == 12``).
    """

    def __init__(self,patch_size:int,in_channels:int,embed_size:int,batch_size:int):
        super().__init__()
        self.patch_size = patch_size
        self.batch_size = batch_size
        self.embed_size = embed_size # embed size is the size of linearly projected patch of image
        self.in_channels = in_channels # channel size of image, 1 for grayscale, 3 for colored image

        self.linear = nn.Linear(self.in_channels*self.patch_size**2,self.embed_size)
        # One shared token each, broadcast over the batch in forward(). Allocating
        # a separate token per batch slot made the output depend on a sample's
        # position inside the batch and broke for any other batch size.
        self.class_token = nn.Parameter(torch.randn(1,1,embed_size))
        self.distill_token = nn.Parameter(torch.randn(1,1,embed_size))
        self.position_embed = nn.Parameter(torch.randn(self.patch_size**2+2,self.embed_size))

    def forward(self,x:torch.Tensor) -> torch.Tensor: # num_batch x in_channel x height x width -> num_batch x (num_patch + 2) x embed_size
        """Embed the patches of ``x`` and append the class and distill tokens.

        Raises:
            ValueError: if the number of patches plus the two tokens does not
                match the number of rows in the position embedding.
        """
        out = einops.rearrange(x,"b c (h px) (w py) -> b (h w) (c px py)",px = self.patch_size, py = self.patch_size)
        out = self.linear(out)
        if out.size(1) + 2 != self.position_embed.size(0):
            raise ValueError(
                f"input yields {out.size(1)} patches but the position embedding "
                f"expects {self.position_embed.size(0) - 2}"
            )
        batch = out.size(0)
        class_token = self.class_token.expand(batch,-1,-1)
        distill_token = self.distill_token.expand(batch,-1,-1)
        out = torch.cat([out,class_token,distill_token],dim = 1)
        out = out + self.position_embed
        return out


In [3]:
class DotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(Q K / sqrt(d)) V``.

    ``key`` must already be transposed, i.e. shaped ``(..., d, n_keys)``, so
    that ``query @ key`` yields the ``(..., n_queries, n_keys)`` score matrix.
    """

    def __init__(self,):
        super().__init__()
        # Normalise over the key axis (the last one) so that each query's
        # attention weights sum to 1.
        self.softmax = nn.Softmax(dim = -1)

    def forward(self,query,key,value):
        """Return the attention-weighted combination of ``value``.

        Args:
            query: tensor shaped ``(..., n_queries, d)``.
            key: pre-transposed tensor shaped ``(..., d, n_keys)``.
            value: tensor shaped ``(..., n_keys, d_v)``.

        Returns:
            Tensor shaped ``(..., n_queries, d_v)``.
        """
        # Scale by the feature dimension d. Because `key` arrives transposed,
        # its last axis is the sequence length, so d is read from `query`.
        scale = query.size(-1) ** 0.5
        weights = self.softmax(torch.matmul(query,key) / scale)
        return torch.matmul(weights,value)

In [4]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention in which every head has width ``embed_size``.

    The query, key and value projections each map ``embed_size`` to
    ``num_heads * embed_size``. The result is split into ``num_heads`` heads,
    attention is computed per head, the heads are concatenated, and a final
    linear layer projects back to ``embed_size``.

    Args:
        embed_size: width of the input and output token embeddings.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the output projection.
    """

    def __init__(self,embed_size,num_heads,dropout = 0.1):
        super().__init__()
        self.num_heads = num_heads
        self.embed_size = embed_size
        self.DotProductAttention = DotProductAttention()

        self.key = nn.Linear(self.embed_size,self.num_heads* self.embed_size,bias = False)
        self.query = nn.Linear(self.embed_size,self.num_heads* self.embed_size,bias = False)
        self.value = nn.Linear(self.embed_size,self.num_heads* self.embed_size,bias = False)

        self.linear = nn.Linear(self.num_heads*self.embed_size,embed_size,bias = False)
        # Not used in forward(); kept so the module's parameter set is unchanged.
        self.layer_norm = nn.LayerNorm(self.embed_size, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self,embed):
        """Apply self-attention to ``embed`` of shape ``(batch, tokens, embed_size)``.

        Returns:
            Tensor of shape ``(batch, tokens, embed_size)``.
        """
        heads = self.num_heads
        # Fold the head axis into the batch axis so that attention is computed
        # independently per head on plain 3-D tensors.
        query = einops.rearrange(self.query(embed),"b n (h d) -> (b h) n d",h = heads)
        # The key uses its own projection (previously the query projection was
        # reused here) and is transposed for the score matmul.
        key = einops.rearrange(self.key(embed),"b n (h d) -> (b h) d n",h = heads)
        value = einops.rearrange(self.value(embed),"b n (h d) -> (b h) n d",h = heads)
        sdp = self.DotProductAttention(query,key,value)
        # Concatenate the heads back along the feature axis.
        sdp = einops.rearrange(sdp,"(b h) n d -> b n (h d)",h = heads)
        return self.dropout(self.linear(sdp))
        


In [5]:
class EncoderBlock(nn.Module):
    """Pre-norm transformer encoder block: self-attention, then a GELU MLP.

    Both sub-layers are wrapped in a residual connection and preceded by their
    own LayerNorm. The MLP expands to ``4 * embed_size`` and projects back.

    Args:
        embed_size: width of the token embeddings.
        num_heads: number of attention heads.
        dropout: dropout probability forwarded to the attention layer.
    """

    def __init__(self,embed_size,num_heads,dropout = 0.1):
        super().__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.dropout =dropout
        # Build the attention layer from the constructor arguments; it was
        # previously hard-coded to (768, 8) regardless of what was passed in.
        self.mha = MultiHeadAttention(self.embed_size,self.num_heads,self.dropout)
        self.Linear1 = nn.Linear(self.embed_size,self.embed_size*4,bias=False)
        self.Linear2 = nn.Linear(self.embed_size*4,self.embed_size,bias = False)
        self.gelu = nn.GELU()
        self.layer_norm1 = nn.LayerNorm(self.embed_size,eps = 1e-6)
        self.layer_norm2 = nn.LayerNorm(self.embed_size,eps = 1e-6)

    def forward(self,embed:torch.Tensor):
        """Return ``embed`` after the attention and MLP residual updates (same shape)."""
        embed = embed + self.mha(self.layer_norm1(embed))
        embed = embed+ self.Linear2(self.gelu(self.Linear1(self.layer_norm2(embed))))
        return embed


In [6]:
class Transformer(nn.Module):
    """Vision transformer with a class head and a distillation head.

    The image is embedded by ``Project``, passed through ``num_encoders``
    ``EncoderBlock`` layers, and the two trailing tokens (class, then distill)
    are fed through a shared two-layer linear classifier.

    Args:
        embed_size: width of the token embeddings.
        num_heads: attention heads per encoder block.
        patch_size: side length in pixels of each image patch.
        in_channels: number of image channels.
        batch_size: forwarded to ``Project``.
        num_encoders: number of stacked encoder blocks.
        num_class: number of output classes.
        device: stored on the instance; ``forward`` no longer needs it.
    """

    def __init__(self,embed_size,num_heads,patch_size,in_channels,batch_size,num_encoders,num_class,device):
        super().__init__()
        self.device = device
        self.num_heads = num_heads
        self.embed_size = embed_size
        self.patch_size =patch_size
        self.in_channels = in_channels
        self.batch_size = batch_size
        self.num_encoders = num_encoders
        self.num_class = num_class
        self.proj = Project(self.patch_size,self.in_channels,self.embed_size,self.batch_size)

        # The plain list is only a convenience handle; the blocks are registered
        # as sub-modules through the nn.Sequential below.
        self.tiny_block = [EncoderBlock(self.embed_size,self.num_heads) for i in range(self.num_encoders)]
        self.block_seq = nn.Sequential(*self.tiny_block)

        self.linear1 = nn.Linear(self.embed_size,self.num_class*4)
        self.linear2 = nn.Linear(self.num_class*4,self.num_class)

    def num_of_parameters(self,):
        """Return the total number of elements across all parameters."""
        return sum(p.numel() for p in self.parameters())

    def forward(self,img):
        """Classify a batch of images.

        Returns:
            Tensor of shape ``(batch, 2, num_class)``: index 0 on axis 1 holds
            the class-token logits and index 1 the distill-token logits.
        """
        out = self.proj(img)
        out = self.block_seq(out)
        # Project appends the class and distill tokens last, so slice them off
        # the end. Slicing keeps the batch axis even when batch == 1 (the old
        # squeeze dropped it) and needs no index tensor on a specific device.
        out = out[:, -2:, :]
        out = self.linear1(out)
        out = self.linear2(out)

        return out

In [8]:
# Instantiate the Transformer defined above with the configured hyperparameters and move it to `device`.
transformer = Transformer(embed_size,num_heads,patch_size,in_channels,batch_size,num_encoders,num_class,device).to(device)

In [9]:
class Custom_Dataset():
    """In-memory image dataset built from a directory with one sub-folder per class.

    Every image is converted to RGB, resized to 144x144, laid out as
    ``(3, 144, 144)`` float32 and scaled to ``[0, 1]``. Labels are the index of
    the class folder in ``os.listdir`` order (non-directory entries are
    ignored), so the label assignment follows the filesystem listing order.

    Attributes:
        path: ``Path`` of the dataset root.
        classes: class folder names, indexed by label.
        x: float tensor of shape ``(N, 3, 144, 144)``.
        y: integer tensor of shape ``(N,)`` holding the labels.

    Raises:
        OSError: if ``directory`` cannot be listed.
        ValueError: if no class folder or no readable image is found.
    """

    def __init__(self, directory):
        self.path = Path(directory)
        # Kept only for backward compatibility with code that may call
        # Path.ls(); this class itself no longer relies on the monkey-patch.
        Path.ls = lambda x: list(x.iterdir())
        try:
            files = os.listdir(directory)
        except OSError:
            # Fail here with the real cause instead of continuing into an
            # unrelated error on the unbound `files` name.
            print("wrong path")
            raise
        print(files)

        class_dirs = [name for name in files if (self.path/name).is_dir()]
        if not class_dirs:
            raise ValueError(f"no class sub-directories found in {directory!r}")
        self.classes = class_dirs

        images, labels = [], []
        for label, name in enumerate(class_dirs):
            for img_path in (self.path/name).iterdir():
                try:
                    with Image.open(img_path) as img:
                        # convert("RGB") makes grayscale, palette and RGBA files
                        # all come out as three channels.
                        pixels = np.array(img.convert("RGB").resize((144,144)))
                except OSError as err:
                    # Skip the one bad file rather than silently dropping every
                    # remaining class.
                    print(f"skipping unreadable image {img_path}: {err}")
                    continue
                images.append(torch.tensor(np.transpose(pixels, (2, 0, 1))).type(torch.FloatTensor))
                labels.append(label)

        if not images:
            raise ValueError(f"no readable images found under {directory!r}")
        # Stack once at the end instead of re-concatenating the full tensor per class.
        self.x = torch.stack(images)/255
        self.y = torch.tensor(labels)

    def __len__(self):
        """Return the number of loaded images."""
        return len(self.x)

    def __getitem__(self, i):
        """Return the ``(image, label)`` pair at index ``i``."""
        return self.x[i], self.y[i]
    

In [10]:
# Load every image under the dataset root into memory (see Custom_Dataset).
# NOTE(review): the URL in the trailing comment names a "good-guysbad-guys" dataset while the
# path points at an animal image dataset — confirm the source link.
dataset = Custom_Dataset("../input/animal-image-classification-dataset/Animal Image Dataset") # https://www.kaggle.com/datasets/gpiosenka/good-guysbad-guys-image-data-set
# Label index -> display name, used by display_examples().
# NOTE(review): Custom_Dataset assigns labels by os.listdir order; confirm this mapping matches that order.
class_names = {0:"spider",1:"horse",2:"butterfly",3:"hen",4:"elephant",5:"sheep",6:"dogs",7:"cow",8:"panda",9:"monkey",10:"squirrel",11:"cats"}

In [11]:
def data_loaders(dataset,batch_size,train_per,valid_per):
    """Randomly partition ``dataset`` into train / validation / test loaders.

    A single random permutation of the indices is cut at ``train_per`` and at
    ``train_per + valid_per`` (both as fractions of the dataset length, floored);
    whatever remains becomes the test split. Each split is wrapped in a
    ``SubsetRandomSampler`` and served by a ``DataLoader`` with two workers
    that drops the final incomplete batch.

    Returns:
        ``(trainloader, validloader, testloader)``.
    """
    total = len(dataset)
    shuffled = torch.randperm(total)
    train_end = int(np.floor(train_per * total))
    valid_end = int(np.floor((train_per + valid_per) * total))
    partitions = (
        shuffled[:train_end],
        shuffled[train_end:valid_end],
        shuffled[valid_end:],
    )

    def build_loader(subset):
        # Every split shares the same loader settings; only the sampler differs.
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=2,
            drop_last=True,
            sampler=SubsetRandomSampler(subset),
        )

    trainloader, validloader, testloader = (build_loader(part) for part in partitions)
    return trainloader, validloader, testloader



In [None]:
# 65% of the data for training, 20% for validation, the remainder for testing; batch size 8.
train_loader,val_loader,test_loader = data_loaders(dataset,8,0.65,0.20)
# Pull one (images, labels) training batch for the preview further down.
batch = next(iter(train_loader))

In [None]:
def display_examples(class_names, images, labels):
    """
        Display up to 8 images from the images array with their corresponding labels.

        Args:
            class_names: mapping from integer label to display name.
            images: indexable collection of channel-first (C, H, W) images.
            labels: indexable collection of labels aligned with ``images``.
    """
    # Never index past the end when fewer than 8 images are supplied.
    count = min(8, len(images))

    fig = plt.figure(figsize=(10,6))
    fig.suptitle("Some examples of images of the dataset", fontsize=16)
    for i in range(count):
        plt.subplot(2,4,i+1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        # imshow expects channel-last data, so move the channel axis to the end.
        plt.imshow(np.transpose(images[i],(1,2,0)), cmap=plt.cm.binary)
        plt.xlabel(class_names[int(labels[i])])
    plt.show()

In [None]:
# Preview the sampled training batch: batch[0] holds the images, batch[1] the labels.
display_examples(class_names,batch[0][:],batch[1][:])

In [None]:
def test(device,testloader, model, criterion,batch_size):
    model.eval()
    testloss = 0
    correct = 0
    i = -1
    with torch.no_grad():
        for data, target in tqdm(testloader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            testloss += loss.item()
            #_, predicted = torch.max(output, 1)
            _,predicted = torch.max(nn.Softmax()(output),dim = 1)
            correct += (predicted == target).sum().item()
            i += 1

    return correct/(i+1)*batch_size

In [None]:
def test_transformer(device,testloader, model, criterion,batch_size):
    model.eval()
    testloss = 0
    correct = 0
    i = -1
    with torch.no_grad():
        for data, target in tqdm(testloader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            class_embed_out,distill_embed_out = torch.squeeze(torch.index_select(output,1,torch.tensor(0).to(device))).to(device), torch.squeeze(torch.index_select(output,1,torch.tensor(1).to(device))).to(device)
            loss = criterion(output, target)
            testloss += loss.item()
            #_, predicted = torch.max(output, 1)
            _,predicted = torch.max(nn.Softmax()(class_embed_out),dim = 1)
            correct += (predicted == target).sum().item()
            i += 1

    return correct/(i+1)*batch_size

In [None]:
def train_baseline(trainloader, validloader, model, optimizer, criterion,epochs,f1,batch_size,device = "cuda"):
    """Train a single-head classifier and save the weights with the lowest validation loss.

    Args:
        trainloader: training loader; must expose ``.sampler`` with a length.
        validloader: validation loader; must expose ``.sampler`` with a length.
        model: module returning ``(batch, num_classes)`` scores.
        optimizer: optimizer holding the parameters of ``model`` to update.
        criterion: loss called as ``criterion(output, target)``.
        epochs: number of passes over ``trainloader``.
        f1: metric called as ``f1(output, target)`` once per batch.
        batch_size: batch size used to rescale the summed per-batch scores.
        device: device for the model and batches (defaults to ``"cuda"``, the
            previously hard-coded value).

    Side effects:
        Writes ``model.pt`` in the current working directory. If no epoch ever
        produces a validation loss that compares ``<=`` the running minimum
        (zero epochs, or NaN losses), the initial weights are saved.
    """
    # float("inf") replaces np.Inf, an alias that NumPy 2.0 removed.
    valid_loss_min = float("inf")
    model = model.to(device)
    # Always have something to save, even if no epoch improves on the minimum.
    best_model_wts = copy.deepcopy(model.state_dict())
    for i in range(epochs):
        print("Epoch - {} Started".format(i+1))

        train_loss = 0.0
        valid_loss = 0.0
        train_score = 0.0
        val_score = 0.0
        model.train()
        for data, target in tqdm(trainloader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            # Weight the batch-mean loss by the batch size so the epoch value
            # can be normalised per sample below.
            train_loss += loss.item()*data.size(0)
            train_score = train_score +f1(output,target)
        model.eval()
        with torch.no_grad():
            for data, target in validloader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                loss = criterion(output, target)
                valid_loss += loss.item()*data.size(0)
                val_score = val_score + f1(output,target)
        train_loss = train_loss/len(trainloader.sampler)
        valid_loss = valid_loss/len(validloader.sampler)
        train_score = batch_size*train_score/len(trainloader.sampler)
        val_score = batch_size*val_score/len(validloader.sampler)
        print(f"F1 Score for train: {train_score}, F1 Score for validation: {val_score} ")
        print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
            i+1, train_loss, valid_loss))

        if valid_loss <= valid_loss_min:
            print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
                valid_loss_min,
                valid_loss))
            best_model_wts = copy.deepcopy(model.state_dict())

            valid_loss_min = valid_loss
    torch.save(best_model_wts, 'model.pt')

In [None]:
# ResNet-50 with pretrained weights fetched through torch.hub; used below as the
# baseline classifier and as the teacher passed to train_distill().
resnet50 = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_resnet50', pretrained=True).to("cuda")
# Freeze every pretrained parameter.
for param in resnet50.parameters():
    param.requires_grad = False
# Replace the classifier with a fresh 12-output layer. It is created after the
# freeze loop, so the loop above does not touch it.
resnet50.fc = nn.Linear(2048,12)
# Move the model (including the newly created fc layer) to the GPU.
resnet50 = resnet50.to("cuda")

In [None]:
# Cross-entropy loss shared by all training and evaluation calls below.
criterion = torch.nn.CrossEntropyLoss()
# Adam over resnet50's parameters only: this optimizer cannot update any other model.
optimizer = torch.optim.Adam(resnet50.parameters(),lr = learning_rate,betas = (beta1,beta2),weight_decay=weight_decay)
# NOTE(review): newer torchmetrics releases may require an explicit `task=` argument here — confirm against the installed version.
f1 = F1Score(num_class).to(device)

In [None]:
# Train the ResNet-50 baseline for 1 epoch; the trailing 8 is the batch size used to rescale the F1 sums.
train_baseline(train_loader,val_loader,resnet50,optimizer,criterion,1,f1,8)

In [None]:
# Evaluate the baseline on the held-out test loader.
test(device,test_loader,resnet50,criterion,batch_size)

In [None]:
def train_transformer(trainloader, validloader, model, optimizer, criterion,epochs,f1,batch_size,device = "cuda"):
    """Train a two-head model on its class head and save the best weights.

    Only the class head (index 0 on axis 1 of the model output) contributes to
    the loss and to the F1 score; the distill head is ignored here.

    Args:
        trainloader: training loader; must expose ``.sampler`` with a length.
        validloader: validation loader; must expose ``.sampler`` with a length.
        model: module returning ``(batch, 2, num_classes)`` scores.
        optimizer: optimizer holding the parameters of ``model`` to update.
        criterion: loss called as ``criterion(class_scores, target)``.
        epochs: number of passes over ``trainloader``.
        f1: metric called as ``f1(class_scores, target)`` once per batch.
        batch_size: batch size used to rescale the summed per-batch scores.
        device: device for the model and batches (defaults to ``"cuda"``, the
            previously hard-coded value).

    Side effects:
        Writes ``model.pt`` in the current working directory. If no epoch ever
        produces a validation loss that compares ``<=`` the running minimum
        (zero epochs, or NaN losses), the initial weights are saved.
    """
    # float("inf") replaces np.Inf, an alias that NumPy 2.0 removed.
    valid_loss_min = float("inf")
    model = model.to(device)
    # Always have something to save, even if no epoch improves on the minimum.
    best_model_wts = copy.deepcopy(model.state_dict())
    for i in range(epochs):
        print("Epoch - {} Started".format(i+1))

        train_loss = 0.0
        valid_loss = 0.0
        train_score = 0.0
        val_score = 0.0
        model.train()
        for data, target in tqdm(trainloader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            # Plain indexing selects the class head and keeps the batch axis
            # even for a batch of one (index_select + squeeze did not).
            class_embed_out = output[:, 0]
            loss = criterion(class_embed_out, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()*data.size(0)
            train_score = train_score +f1(class_embed_out,target)
        model.eval()
        with torch.no_grad():
            for data, target in validloader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                class_embed_out = output[:, 0]
                loss = criterion(class_embed_out, target)
                valid_loss += loss.item()*data.size(0)
                val_score = val_score + f1(class_embed_out,target)
        train_loss = train_loss/len(trainloader.sampler)
        valid_loss = valid_loss/len(validloader.sampler)
        train_score = batch_size*train_score/len(trainloader.sampler)
        val_score = batch_size*val_score/len(validloader.sampler)
        print(f"F1 Score for train: {train_score}, F1 Score for validation: {val_score} ")
        print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
            i+1, train_loss, valid_loss))

        if valid_loss <= valid_loss_min:
            print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
                valid_loss_min,
                valid_loss))

            best_model_wts = copy.deepcopy(model.state_dict())

            valid_loss_min = valid_loss
    torch.save(best_model_wts, 'model.pt')

In [None]:
# The shared `optimizer` was built over resnet50.parameters(), so passing it here
# would never update the transformer. Give the transformer its own optimizer.
transformer_optimizer = torch.optim.Adam(transformer.parameters(),lr = learning_rate,betas = (beta1,beta2),weight_decay=weight_decay)
train_transformer(train_loader,val_loader,transformer,transformer_optimizer,criterion,1,f1,8)

In [None]:
# The transformer returns (batch, 2, num_class) scores, so it has to be scored
# with test_transformer(), which reads the class head; test() expects a
# single-head (batch, num_class) output.
test_transformer(device,test_loader,transformer,criterion,batch_size)

In [None]:
def hard_distill_loss(criterion,label,teacher_out,distill_out,class_embed_out,temprature = 1):
    """Hard-label distillation loss: mean of the teacher loss and the ground-truth loss.

    The distill head is trained against the teacher's arg-max decision and the
    class head against the true label; the two losses are weighted equally.

    Args:
        criterion: loss that takes raw scores and integer targets (the notebook
            passes ``torch.nn.CrossEntropyLoss``).
        label: ground-truth class indices, shape ``(batch,)``.
        teacher_out: teacher scores, shape ``(batch, num_classes)``.
        distill_out: student distill-head scores, shape ``(batch, num_classes)``.
        class_embed_out: student class-head scores, shape ``(batch, num_classes)``.
        temprature: divisor applied to both student heads before the loss.

    Returns:
        Scalar loss tensor.
    """
    teacher_desicion = torch.argmax(teacher_out,dim = 1)
    # Pass the temperature-scaled scores straight to the criterion. Applying
    # softmax first fed probabilities into a loss that normalises again,
    # flattening the gradients.
    teacher_loss = criterion(distill_out/temprature,teacher_desicion)
    gt_loss = criterion(class_embed_out/temprature,label)
    hardDistillGlobal = (1/2)*teacher_loss + (1/2)*gt_loss
    return hardDistillGlobal

In [None]:
def train_distill(trainloader, validloader, model_student,model_teacher, optimizer,criterion,epochs,f1,batch_size,device = "cuda"):
    """Train a two-head student with hard-label distillation from a fixed teacher.

    The student's class head learns from the ground-truth labels and its
    distill head from the teacher's arg-max predictions (see
    ``hard_distill_loss``). The F1 score is computed on the class head.

    Args:
        trainloader: training loader; must expose ``.sampler`` with a length.
        validloader: validation loader; must expose ``.sampler`` with a length.
        model_student: module returning ``(batch, 2, num_classes)`` scores.
        model_teacher: module returning ``(batch, num_classes)`` scores; it is
            put in eval mode and never updated here.
        optimizer: optimizer holding the student parameters to update.
        criterion: loss forwarded to ``hard_distill_loss``.
        epochs: number of passes over ``trainloader``.
        f1: metric called as ``f1(class_scores, target)`` once per batch.
        batch_size: batch size used to rescale the summed per-batch scores.
        device: device for both models and the batches (defaults to ``"cuda"``,
            the previously hard-coded value).

    Side effects:
        Writes the student weights to ``model.pt`` in the current working
        directory. If no epoch ever produces a validation loss that compares
        ``<=`` the running minimum (zero epochs, or NaN losses), the initial
        weights are saved.
    """
    # float("inf") replaces np.Inf, an alias that NumPy 2.0 removed.
    valid_loss_min = float("inf")
    model_student = model_student.to(device)
    model_teacher = model_teacher.to(device)
    model_teacher.eval()
    # Always have something to save, even if no epoch improves on the minimum.
    best_model_wts = copy.deepcopy(model_student.state_dict())
    for i in range(epochs):
        print("Epoch - {} Started".format(i+1))

        train_loss = 0.0
        valid_loss = 0.0
        train_score = 0.0
        val_score = 0.0
        model_student.train()
        for data, target in tqdm(trainloader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            student_output = model_student(data)
            # The teacher only supplies arg-max targets, so no autograd graph
            # is needed for it.
            with torch.no_grad():
                teacher_output = model_teacher(data)
            # Plain indexing keeps the batch axis even for a batch of one.
            class_embed_out,distill_embed_out = student_output[:, 0], student_output[:, 1]
            loss = hard_distill_loss(criterion,target,teacher_output,distill_embed_out,class_embed_out)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()*data.size(0)
            train_score = train_score +f1(class_embed_out,target)
        model_student.eval()
        with torch.no_grad():
            for data, target in validloader:
                data, target = data.to(device), target.to(device)
                student_output = model_student(data)
                teacher_output = model_teacher(data)
                class_embed_out,distill_embed_out = student_output[:, 0], student_output[:, 1]
                loss = hard_distill_loss(criterion,target,teacher_output,distill_embed_out,class_embed_out)
                valid_loss += loss.item()*data.size(0)
                val_score = val_score + f1(class_embed_out,target)
        train_loss = train_loss/len(trainloader.sampler)
        valid_loss = valid_loss/len(validloader.sampler)
        train_score = batch_size*train_score/len(trainloader.sampler)
        val_score = batch_size*val_score/len(validloader.sampler)
        print(f"F1 Score for train: {train_score}, F1 Score for validation: {val_score} ")
        print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
            i+1, train_loss, valid_loss))

        if valid_loss <= valid_loss_min:
            print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
                valid_loss_min,
                valid_loss))
            best_model_wts = copy.deepcopy(model_student.state_dict())

            valid_loss_min = valid_loss
    torch.save(best_model_wts, 'model.pt')

In [None]:
# The shared `optimizer` holds resnet50's parameters (the teacher here), so the
# student transformer needs its own optimizer to be trained at all.
distill_optimizer = torch.optim.Adam(transformer.parameters(),lr = learning_rate,betas = (beta1,beta2),weight_decay=weight_decay)
train_distill(train_loader,val_loader,transformer,resnet50,distill_optimizer,criterion,10,f1,8)