In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
from os.path import join as opj
import cv2

import torch.nn as nn
import torchvision
from torchvision import models,transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import accuracy_score, roc_curve
from sklearn.metrics.pairwise import euclidean_distances
from scipy.spatial.distance import pdist

np.random.seed(123)
torch.cuda.manual_seed_all(5)
torch.backends.cudnn.deterministic=True
torch.manual_seed(123)
BATCH_SIZE = 64
EMBEDDING_SIZE = 512

In [2]:
class CelebADataset(Dataset):
    
    def __init__(self, base_path, transform=None):
        self.images_path = opj(base_path, "images")
        self.names = []
        self.labels = []
        identity_list = opj(base_path, "identity_CelebA.txt")
        
        for l in open(identity_list):
            name, label = l.split(sep=" ")
            self.names.append(name.replace('jpg','png'))
            self.labels.append(int(label.replace("\n","")))
            
        self.transform = transform
        self.labels = torch.Tensor(self.labels).long()
        
    
    def __len__(self):
        return len(self.labels)
    
    
    def __getitem__(self,idx):
        
        label = self.labels[idx]
        img = cv2.imread(opj(self.images_path,self.names[idx]))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        #img = cv2.resize(img, (256,256))
        if self.transform is not None:
            img = self.transform(img)
        
        return {"image":img, "label":label}

In [3]:
g = torch.Generator()
g.manual_seed(123)

<torch._C.Generator at 0x223be8ef7b0>

In [4]:
transform = transforms.Compose([transforms.ToTensor()])

device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
    
device

'cuda'

In [5]:
dataset = CelebADataset(".\\CelebA", transform=transform)

train_size = int(0.8 * len(dataset))
test_size = int(0.2 * len(dataset)) + 1

train_dataset, test_dataset = torch.utils.data.random_split(dataset, (train_size,test_size))
train_loader = DataLoader(train_dataset,batch_size=BATCH_SIZE, shuffle=True,generator=g)
test_loader = DataLoader(test_dataset,batch_size=BATCH_SIZE, shuffle=False,generator=g)

In [6]:
acos_eps = 1e-6

def ArcfaceLoss(y_pred,y_true):
    
    denominators = torch.sum(torch.exp(s*y_pred), dim=1)
    denominators = denominators - torch.exp(torch.sum(y_true * y_pred * s, dim=1))
    cos_thetas = torch.sum(y_true * y_pred, dim=1)
    cos_thetas = torch.clamp(cos_thetas,-1+acos_eps, 1-acos_eps)
    thetas = torch.acos(cos_thetas)
    new_cos_thetas = s * torch.cos(thetas + m)
    numerators = torch.exp(new_cos_thetas)
    denominators = denominators + numerators
    loss = -torch.mean(torch.log(numerators / denominators))
    

    return loss

In [7]:
class Backbone(nn.Module):

    def __init__(self):
        super(Backbone,self).__init__()
        self.conv1 = nn.Conv2d(3, 8, 3,1,1)
        self.bn8 = nn.BatchNorm2d(8)
        self.bn16 = nn.BatchNorm2d(16)
        self.bn32 = nn.BatchNorm2d(32)
        self.bn64 = nn.BatchNorm2d(64)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(8, 16, 3,1,1)
        self.conv3 = nn.Conv2d(16,32,3,1,1)
        self.conv4 = nn.Conv2d(32,64,3,1,1)
        self.fc1 = nn.Linear(9152, 2048)
        self.fc2 = nn.Linear(2048, EMBEDDING_SIZE)

    def forward(self, x):
        x = self.pool(F.leaky_relu(self.conv1(x)))
        x = self.bn8(x)
        x = self.pool(F.leaky_relu(self.conv2(x)))
        x = self.bn16(x)
        x = self.pool(F.leaky_relu(self.conv3(x)))
        x = self.bn32(x)
        x = self.pool(F.leaky_relu(self.conv4(x)))
        x = self.bn64(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        emb = F.normalize(x, p=2.0, dim=1)
        
        return emb

In [8]:
class Backbone(nn.Module):

    def __init__(self):
        super(Backbone,self).__init__()
        self.resnet = torchvision.models.resnet18(weights=None)
        self.resnet.fc = nn.Sequential()

    def forward(self, x):
        x = F.relu(self.resnet(x))
        emb = F.normalize(x, p=2.0, dim=1)
        
        return emb

In [9]:
class ArcfaceNN(Backbone):
    def __init__(self,n_classes):
        super(ArcfaceNN,self).__init__()
        self.n_classes = n_classes
        self.fc = nn.Linear(EMBEDDING_SIZE, self.n_classes)
        
    def forward(self,x):
        emb = super(ArcfaceNN,self).forward(x)
        
        for W in self.fc.parameters():
            W = F.normalize(W, p=2, dim=0)
        
        x = self.fc(emb)
        
        
        return x
    
    def make_emb(self,x):
        return super(ArcfaceNN,self).forward(x)

In [10]:
n_classes = int(dataset.labels.max()) + 1
ARCFACE_EPOCHS = 50
s = 64.0
m = 0.5
arcface_model = ArcfaceNN(n_classes).to(device).train()
lr = 0.0001
criterion = ArcfaceLoss
optimizer = optim.Adam(arcface_model.parameters(), lr=lr)

In [None]:
%%time

epoch = 0
while(epoch <= ARCFACE_EPOCHS):
    epoch_loss = 0
    for i, data in enumerate(train_loader):
        if i % 500 == 499:
            print(f"batch {i} of {len(train_loader)}")
        image, label = data['image'],data['label']
        image = image.to(device)
        label = label.to(device)
        optimizer.zero_grad()
        
        y_pred = arcface_model.forward(image)
        label = F.one_hot(label,num_classes=n_classes)
        
        loss = criterion(y_pred,label) 
        epoch_loss +=loss
        loss.backward()
        optimizer.step()
        
    now = datetime.now()
    current_time = now.strftime("%H:%M:%S")
    torch.save(arcface_model, f".\\arcface_new_{epoch_loss}.pt")
    print(f'[{epoch}] loss: {epoch_loss}')
    epoch += 1

batch 499 of 2533
batch 999 of 2533
batch 1499 of 2533
batch 1999 of 2533
batch 2499 of 2533
[0] loss: 26231.052734375
batch 499 of 2533
batch 999 of 2533
batch 1499 of 2533
batch 1999 of 2533
batch 2499 of 2533
[1] loss: 23529.451171875
batch 499 of 2533
batch 999 of 2533
batch 1499 of 2533
batch 1999 of 2533
batch 2499 of 2533
[2] loss: 21054.9921875
batch 499 of 2533
batch 999 of 2533
batch 1499 of 2533
batch 1999 of 2533
batch 2499 of 2533
[3] loss: 17742.369140625
batch 499 of 2533
batch 999 of 2533
batch 1499 of 2533
batch 1999 of 2533
batch 2499 of 2533
[4] loss: 14419.4189453125
batch 499 of 2533
batch 999 of 2533
batch 1499 of 2533
batch 1999 of 2533
batch 2499 of 2533
[5] loss: 11426.54296875
batch 499 of 2533
batch 999 of 2533
batch 1499 of 2533
batch 1999 of 2533
batch 2499 of 2533
[6] loss: 8864.9814453125
batch 499 of 2533
batch 999 of 2533
batch 1499 of 2533
batch 1999 of 2533
batch 2499 of 2533
[7] loss: 6798.6142578125
batch 499 of 2533
batch 999 of 2533
batch 1499 of 

In [None]:
arcface_model.fc = nn.Sequential()

In [None]:
device = 'cpu'

In [None]:
del train_dataset
del train_loader
del g

In [None]:
labels = np.array([sample['label'] for sample in test_dataset])

In [None]:
n_samples = len(labels)

In [None]:
embs = []
arcface_model = arcface_model.to(device).eval()

for data in test_loader:
        image, labebl = data['image'],data['label']
        image = image.to(device)
        embs.extend(arcface_model.forward(image).detach().cpu().numpy())


In [None]:
diff_class = labels[:, np.newaxis] != labels[np.newaxis, :]
diff_class = torch.Tensor(diff_class).to(device).bool().flatten()

In [None]:
idc = torch.triu_indices(n_samples,n_samples,1,device=device)
idc = idc[0] * n_samples + idc[1]
diff_class = diff_class[idc]

In [None]:
#dists = euclidean_distances(embs,embs).flatten()
dists_np = pdist(embs).astype('float16')
dists = torch.Tensor(dists_np).to(device)

In [None]:
del embs
del test_loader
del test_dataset
del labels
del idc

In [None]:
diff_class_np = diff_class.cpu().numpy().astype('bool')

In [None]:
np.sum(~diff_class_np)

In [None]:
plt.figure(figsize=(10,4))
plt.xlim(0.0,2)
plt.hist(dists_np[~diff_class_np],alpha = 0.5,color="blue",label="Positive pairs",range=(0.01,2))
plt.hist(dists_np[diff_class_np][:np.sum(~diff_class_np)],alpha = 0.5,color="red",label="Negative pairs") #what?
plt.legend()

In [None]:
sort_index = np.argsort(dists_np)
diff_class_np = diff_class_np[sort_index]
dists_np = dists_np[sort_index]
del sort_index

In [None]:
#TP = np.sum(np.logical_and(preds == 1, y_true == 1))
#TN = np.sum(np.logical_and(preds == 0, y_true == 0))
#FP = np.sum(np.logical_and(preds == 1, y_true == 0))
#FN = np.sum(np.logical_and(preds == 0, y_true == 1))
#sort_index = np.argsort(distances)
#y_true = y_true[sort_index]
#distances = distances[sort_index]

In [None]:
def roc_curve_my(y_true, distances):
    threshes = []
    fpr = []
    tpr = []
    
    
        
    start_i = 0
    end_i = 0
    
    preds = distances >= 0.0
    TP = np.sum(y_true == 1)
    TN = 0
    FP = len(y_true) - TP
    FN = 0
    
    TPFN = TP + FN
    FPTN = FP + TN

    
    tpr.append(TP / TPFN)
    fpr.append(FP/ FPTN)
    
    for thresh in np.arange(0.0,2.001,0.001):
        while(distances[end_i] < thresh):
            end_i +=1
            if(end_i >= len(distances) -1):
                end_i = len(distances) -1
                break
        
        if(start_i == end_i):
            continue
        slc_true = y_true[start_i:end_i]
        
        
        a_FN = np.sum(slc_true == 1)
        a_TN = np.sum(slc_true == 0)
        
        TP -= a_FN
        FP -= a_TN
        
        
        tpr.append(TP / TPFN)
        fpr.append(FP/ FPTN)
        threshes.append(thresh)
        
        start_i = end_i
        if(end_i >= len(distances) -1):
            end_i = len(distances) -1
    threshes = np.append(threshes,10.0)
    return fpr,tpr,threshes
        

In [None]:
%%time
mfpr,mtpr,thresh = roc_curve_my(diff_class_np,dists_np)

In [None]:
plt.figure(figsize=(7,5))
plt.xscale('log')
plt.yticks(np.arange(0,1,0.05))
plt.xlim(0.00005,1)
plt.plot(mfpr,mtpr)
plt.title("ROC curve")
plt.ylabel('True Positive Rate')  
plt.xlabel('False Positive Rate')  
plt.grid()
plt.show()  

In [None]:
closest_pt = np.argmin(np.linalg.norm(np.array([1.0, 0.0]) - np.vstack([mtpr,mfpr]).T,axis=1))

In [None]:
plt.figure(figsize=(7,5))
plt.xscale('log')
plt.yticks(np.arange(0,1,0.05))
plt.xlim(0.00006,1)
plt.plot(mfpr,mtpr)
plt.scatter(mfpr[closest_pt],mtpr[closest_pt],c="red")
plt.title("ROC curve")
plt.ylabel('True Positive Rate')  
plt.xlabel('False Positive Rate')  
plt.grid()
plt.show()  

In [None]:
best_thresh = thresh[closest_pt]
y_preds = dists_np > best_thresh

In [None]:
print("Accuracy score: ", accuracy_score(diff_class_np, y_preds))
print("Best thresh: ", best_thresh)

inp = torch.rand(1, 3, 122, 122).to(device)
traced_script_module = torch.jit.trace(arcface_model, inp)
traced_script_module.save("traced_arcface_model.pt")