In [None]:
import torch 
import torch.nn.functional as F

from torch import nn, optim 
from torch.utils.data import DataLoader
from torchvision import transforms as T, datasets
from torchvision.models import efficientnet_b1

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 
import plotly.express as px

from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score

import math

# Run on the default CUDA device when one is visible, otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# `transform` was previously used without ever being defined.
# MNIST_Model (defined below) takes 1-channel input and its fc1 width (384 = 24*4*4)
# is sized for 28x28 images, whereas CIFAR-10 images are 3x32x32 RGB. Convert to a
# single grayscale channel and resize so the data actually fits the model.
# NOTE(review): if MNIST was the intended dataset, swap in datasets.MNIST and drop
# the Grayscale/Resize steps.
transform = T.Compose([
    T.Grayscale(num_output_channels=1),
    T.Resize((28, 28)),
    T.ToTensor(),
])

trainset = datasets.CIFAR10(root='/content/dataset', download=True, train=True, transform=transform)
testset = datasets.CIFAR10(root='/content/dataset', download=True, train=False, transform=transform)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
class ArcFace(nn.Module):
    """Implementation for "ArcFace: Additive Angular Margin Loss for Deep Face Recognition".

    Returns scaled cosine logits between L2-normalised features and L2-normalised
    class weights. The label column gets the additive angular margin,
    ``scale * cos(theta + margin_arc)``. Every other column is ``scale * cos(theta)``.
    When ``theta + margin_arc`` would pass pi (``cos(theta) <= cos(pi - margin_arc)``),
    the label column falls back to ``scale * (cos(theta) - margin_am)``.

    Args:
        feat_dim: size of each incoming feature vector.
        num_class: number of classes (columns of ``weight``).
        margin_arc: additive angular margin, in radians.
        margin_am: additive cosine margin used in the fallback branch.
        scale: multiplier applied to all logits.
        eps: lower bound on ``1 - cos^2`` before the square root. It keeps the
            gradient finite when a feature is exactly (anti-)parallel to a class
            weight, where d/dx sqrt(x) would otherwise be infinite.
    """

    def __init__(self, feat_dim, num_class, margin_arc=0.4, margin_am=0.0, scale=32, eps=1e-12):
        super(ArcFace, self).__init__()
        # Column j holds the (unnormalised) weight vector for class j.
        self.weight = nn.Parameter(torch.empty(feat_dim, num_class))
        with torch.no_grad():
            # Uniform init, then cap every column's L2 norm at 1e-5 and scale back
            # up by 1e5, which leaves each column with norm at most 1.
            self.weight.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        self.margin_arc = margin_arc
        self.margin_am = margin_am
        self.scale = scale
        self.eps = eps
        # Precomputed so forward can use cos(a + m) = cos a cos m - sin a sin m.
        self.cos_margin = math.cos(margin_arc)
        self.sin_margin = math.sin(margin_arc)
        # Above this cosine, theta + margin_arc stays below pi (cos stays monotonic).
        self.min_cos_theta = math.cos(math.pi - margin_arc)

    def forward(self, feats, labels):
        """Compute margin-adjusted logits.

        Args:
            feats: 2-D tensor ``(batch, feat_dim)``; it is L2-normalised along dim 1 here.
            labels: integer class indices, one per row of ``feats``.

        Returns:
            Tensor ``(batch, num_class)`` of scaled logits.
        """
        kernel_norm = F.normalize(self.weight, dim=0)
        feats = F.normalize(feats)

        cos_theta = torch.mm(feats, kernel_norm).clamp(-1, 1)
        # Clamp before sqrt: at cos = +-1 the raw sqrt(0) has an infinite derivative,
        # which turned into inf/NaN gradients even for columns never selected below.
        sin_theta = torch.sqrt((1.0 - cos_theta.pow(2)).clamp_min(self.eps))

        cos_theta_m = cos_theta * self.cos_margin - sin_theta * self.sin_margin
        cos_theta_m = torch.where(cos_theta > self.min_cos_theta, cos_theta_m, cos_theta - self.margin_am)

        # Apply the margin only on each row's label column.
        target_mask = F.one_hot(labels.reshape(-1), num_classes=cos_theta.size(1)).bool()
        output = torch.where(target_mask, cos_theta_m, cos_theta)
        return output * self.scale


class MNIST_Model(nn.Module):
    """Small CNN producing unit-norm 3-D embeddings, with an optional ArcFace head.

    The first conv takes 1 input channel, and fc1's width (384 = 24 * 4 * 4) matches
    28x28 inputs. Calling ``model(x)`` returns the normalised embeddings. Calling
    ``model(x, targets)`` returns ArcFace logits over 10 classes instead.
    """

    def __init__(self):
        super(MNIST_Model, self).__init__()

        # Layer creation order is significant: it fixes parameter order and the RNG
        # draws used for initialisation.
        self.conv1 = nn.Conv2d(1, 12, kernel_size=5)
        self.conv2 = nn.Conv2d(12, 24, kernel_size=4)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(384, 72)
        self.fc2 = nn.Linear(72, 3)
        self.arc_face = ArcFace(feat_dim=3, num_class=10)

    def forward(self, features, targets=None):
        """Embed ``features``; if ``targets`` is given, return ArcFace logits instead."""
        hidden = F.max_pool2d(self.conv1(features), 2)
        hidden = F.relu(hidden)

        hidden = F.max_pool2d(self.conv2_drop(self.conv2(hidden)), 2)
        hidden = F.relu(hidden)

        n_chan, n_h, n_w = hidden.shape[1:]
        hidden = hidden.view(-1, n_chan * n_h * n_w)

        hidden = F.relu(self.fc1(hidden))
        embedding = F.normalize(self.fc2(hidden))

        if targets is None:
            return embedding
        return self.arc_face(embedding, targets)

In [None]:
class TrainModel():
    """Epoch-based train/validate driver for models called as ``model(images, labels)``.

    Args:
        criterion: loss callable ``criterion(logits, labels)`` returning a scalar tensor.
            Epoch averages assume mean reduction (the ``nn.CrossEntropyLoss`` default).
        optimizer: optimizer over the model's parameters.
        schedular: stored on the instance. ``run`` does not step it.
            NOTE(review): misspelled name kept for backward compatibility.
        device: device to move batches (and the model) to. ``None`` means CUDA when
            available, otherwise CPU.
    """

    def __init__(self, criterion=None, optimizer=None, schedular=None, device=None):
        self.criterion = criterion
        self.optimizer = optimizer
        self.schedular = schedular
        # Previously a None device was stored as-is and the loops silently read the
        # notebook-global `device` instead. Resolve it here and use it consistently.
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = torch.device(device)
        # One dict of epoch metrics per call-epoch of `run`.
        self.history = []

    def accuracy(self, logits, labels):
        """Return the fraction of rows whose arg-max logit equals the label."""
        preds = torch.argmax(logits, dim=1)
        return (preds == labels.to(preds.device)).float().mean().item()

    def get_dataloader(self, trainset, validset, train_batch_size=26, valid_batch_size=4, num_workers=4):
        """Build (train, valid) loaders. Only the training loader is shuffled."""
        # Pinned memory only helps host-to-GPU copies.
        pin = self.device.type == 'cuda'
        trainloader = DataLoader(trainset, batch_size=train_batch_size, shuffle=True,
                                 num_workers=num_workers, pin_memory=pin)
        validloader = DataLoader(validset, batch_size=valid_batch_size, shuffle=False,
                                 num_workers=num_workers, pin_memory=pin)
        return trainloader, validloader

    def _run_epoch(self, model, loader, i, phase):
        """One pass over `loader`. Backprop/step only when `phase` is 'TRAIN'.

        Returns (mean loss per sample, accuracy). Both are NaN for an empty loader.
        """
        is_train = phase == 'TRAIN'
        total_loss, n_correct, n_seen = 0.0, 0, 0
        pbar = tqdm(loader, desc=f"Epoch [{phase}] {i + 1}")

        # Validation needs no autograd graph; building one only wastes memory.
        with torch.set_grad_enabled(is_train):
            for images, labels in pbar:
                images = images.to(self.device, non_blocking=True)
                labels = labels.to(self.device, non_blocking=True)

                # NOTE(review): with an ArcFace head, passing labels applies the margin
                # to the target logit, so validation accuracy here is pessimistic.
                logits = model(images, labels)
                loss = self.criterion(logits, labels)

                if is_train:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                # Weight by batch size so a short final batch doesn't skew the average.
                batch_size = labels.size(0)
                total_loss += loss.item() * batch_size
                n_correct += (logits.argmax(dim=1) == labels).sum().item()
                n_seen += batch_size

                pbar.set_postfix({'loss': '%.6f' % (total_loss / n_seen),
                                  'acc': '%.6f' % (n_correct / n_seen)})

        if n_seen == 0:
            return float('nan'), float('nan')
        return total_loss / n_seen, n_correct / n_seen

    def train_batch_loop(self, model, trainloader, i):
        """Train for one epoch. Returns (avg loss, avg accuracy)."""
        return self._run_epoch(model, trainloader, i, 'TRAIN')

    def valid_batch_loop(self, model, validloader, i):
        """Evaluate for one epoch without gradients. Returns (avg loss, avg accuracy)."""
        return self._run_epoch(model, validloader, i, 'VALID')

    def run(self, model, trainset, validset, epochs):
        """Train for `epochs` epochs, checkpointing after each one. Returns the model."""
        trainloader, validloader = self.get_dataloader(trainset, validset)
        model.to(self.device)

        for i in range(epochs):
            model.train()
            avg_train_loss, avg_train_acc = self.train_batch_loop(model, trainloader, i)

            # NOTE(review): this pickles the whole module. Saving model.state_dict() is
            # more portable, but would change the checkpoint format.
            torch.save(model, f'/content/model_{i}.pth')

            model.eval()
            avg_valid_loss, avg_valid_acc = self.valid_batch_loop(model, validloader, i)

            self.history.append({
                'epoch': i + 1,
                'train_loss': avg_train_loss, 'train_acc': avg_train_acc,
                'valid_loss': avg_valid_loss, 'valid_acc': avg_valid_acc,
            })

        return model

In [None]:
LEARNING_RATE = 0.00075
NUM_EPOCHS = 3

# `Net` is not defined anywhere in this notebook; the model built above is MNIST_Model.
model = MNIST_Model().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Pass `device` by keyword: TrainModel's third positional parameter is the scheduler,
# so `TrainModel(criterion, optimizer, device)` put the device in the wrong slot.
model = TrainModel(criterion, optimizer, device=device).run(model, trainset, testset, NUM_EPOCHS)

In [None]:
# Collect the model's unit-norm 3-D embeddings (no targets, so no ArcFace head)
# for every test image, together with their labels.
model.eval()  # disable Dropout2d explicitly instead of relying on earlier cells
emb = []
label_batches = []

testloader = DataLoader(testset, batch_size=64)
with torch.no_grad():
    for images, labels in tqdm(testloader):
        embeddings = model(images.to(device))
        emb.append(embeddings.cpu())
        label_batches.append(labels)

embs = torch.cat(emb).numpy()
y = torch.cat(label_batches).numpy()
assert embs.shape == (len(y), 3)

  0%|          | 0/157 [00:00<?, ?it/s]

In [None]:
# The embedding head is already 3-D, so the embeddings are plotted directly. No t-SNE
# is applied, despite the frame's name (kept for compatibility).
# Class labels are cast to str so plotly uses a discrete palette instead of a
# continuous colorbar (stacking them with float embeddings made them floats).
tsne_df = pd.DataFrame(embs, columns=["x", "y", "z"]).assign(targets=y.astype(str))

fig = px.scatter_3d(
    tsne_df, x="x", y="y", z="z", color="targets",
    category_orders={"targets": sorted(tsne_df["targets"].unique(), key=int)},
    labels={"targets": "class"},
    title="ArcFace test-set embeddings by class",
)
fig.show()