In [1]:
import torch
import torch.nn as nn
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader

import os
import numpy as np
from PIL import Image
from tqdm import tqdm
from collections import defaultdict
from sklearn.metrics import f1_score, confusion_matrix, precision_score, recall_score
import pickle
import json

from transformers import CLIPProcessor, CLIPModel

In [2]:
# Output width of the classification heads attached to both backbones in this notebook.
# The dataset classes also compare saved labels against N_CLASSES + 1 and skip entries
# carrying that value when merging a previously stored mapping.
N_CLASSES = 3000

In [None]:
# Both backbones are fetched from the same pinned torchvision hub release,
# with pretrained weights requested for each architecture.
_HUB_REPO = 'pytorch/vision:v0.10.0'
model1, model2 = (
    torch.hub.load(_HUB_REPO, _arch, pretrained=True)
    for _arch in ('resnet18', 'inception_v3')
)

In [4]:
# Replace the final fully-connected layer of each backbone with a fresh
# N_CLASSES-way linear head (bias is enabled by default in nn.Linear).
for _backbone, _in_features in ((model1, 512), (model2, 2048)):
    _backbone.fc = nn.Linear(_in_features, N_CLASSES)

In [5]:
def transform(img):
    """Convert an image into a normalized 3 x 299 x 299 float tensor.

    Args:
        img: a PIL image or an array-like image. PIL inputs are converted to
            RGB first; 2-D (single channel) arrays are replicated to three
            channels, because the normalization below uses 3-channel statistics.

    Returns:
        torch.Tensor produced by ToTensor -> Normalize -> centre square crop
        -> resize to 299 x 299 (same operation order as before).
    """
    # Grayscale / palette / RGBA files would otherwise break the shape unpack
    # or the 3-channel Normalize.
    if isinstance(img, Image.Image):
        img = img.convert('RGB')
    img = np.array(img)
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)

    # numpy image arrays are laid out (rows, cols, channels) = (height, width, C).
    height, width = img.shape[:2]

    T = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        # Crop the largest centred square so the resize does not distort the aspect ratio.
        transforms.CenterCrop(min(width, height)),
        transforms.Resize((299, 299))
    ])

    return T(img)

In [6]:
class CelebaClassificationDataset(Dataset):
    """Images of one split directory paired with contiguous integer class labels.

    Raw identities are read from ``identity_CelebA.txt`` (one ``<file> <id>`` pair
    per line) and re-indexed so that identities seen in ``train`` come first,
    followed by identities first seen in ``val`` and then in ``test``.

    Args:
        mode: name of the directory whose files make up this dataset.
        class_mapping: accepted for backward compatibility; not used.
        prev_class_mapping: optional path to a JSON file mapping image paths to
            labels. Its entries override the computed labels (with ``train/``
            in the key replaced by ``<mode>/``), except entries whose value is
            ``N_CLASSES + 1``, which are skipped.
    """

    def __init__(self, mode='train', class_mapping=None, prev_class_mapping=None):
        with open('identity_CelebA.txt') as f:
            label_mapping = dict(line.replace('\n', '').split(' ') for line in f)
        raw_identity = {name: int(value) for name, value in label_mapping.items()}

        # Sets make the "already seen in an earlier split" checks O(1) instead
        # of scanning a list for every file.
        train_ids = {raw_identity[name] for name in os.listdir('train')}
        val_ids = {raw_identity[name] for name in os.listdir('val')} - train_ids
        test_ids = {raw_identity[name] for name in os.listdir('test')} - train_ids - val_ids

        self.train_classes = sorted(train_ids)
        self.val_classes = sorted(val_ids)
        self.test_classes = sorted(test_ids)

        self.all_classes = self.train_classes + self.val_classes + self.test_classes

        self.class_mapping = {idx_old: idx_new for idx_new, idx_old in enumerate(self.all_classes)}

        self.images = [f'{mode}/{name}' for name in sorted(os.listdir(mode))]
        # Every file listed in the identity file gets a key under this split's
        # prefix; identities absent from all three splits map to None.
        self.img2class = {
            f'{mode}/{name}': self.class_mapping.get(raw_id)
            for name, raw_id in raw_identity.items()
        }

        if prev_class_mapping:
            # Context manager so the JSON file handle is closed deterministically.
            with open(prev_class_mapping, 'r') as f:
                saved_mapping = json.load(f)
            for key, value in saved_mapping.items():
                if value != N_CLASSES + 1:
                    self.img2class[key.replace('train/', f'{mode}/')] = value

    def __len__(self):
        """Number of image files found in the split directory."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(transformed image tensor, label)`` for the idx-th file."""
        path = self.images[idx]
        # Close the underlying file as soon as the pixels have been converted.
        with Image.open(path) as img:
            tensor = transform(img)
        label = self.img2class[path]
        return tensor, label

In [7]:
# Persist the image -> label mapping that the dataset constructors read back
# through `prev_class_mapping='img2class.json'`.
#
# `train_ds` is only created in a later cell (which itself reads this file), so
# on a fresh kernel the original unconditional dump raised NameError. Now:
#   * if `train_ds` already exists, dump its mapping (original behaviour);
#   * otherwise, if no file exists yet, bootstrap it from a freshly built
#     train dataset;
#   * otherwise leave the existing file untouched.
# NOTE(review): confirm the bootstrap branch matches how img2class.json was
# originally produced.
if 'train_ds' in globals():
    _mapping_to_save = train_ds.img2class
elif not os.path.exists('img2class.json'):
    _mapping_to_save = CelebaClassificationDataset('train').img2class
else:
    _mapping_to_save = None

if _mapping_to_save is not None:
    with open('img2class.json', 'w') as f:
        json.dump(_mapping_to_save, f)

In [8]:
batch_size = 16

# One dataset per split directory, each merging the labels stored in img2class.json.
train_ds, val_ds, test_ds = (
    CelebaClassificationDataset(_split, prev_class_mapping='img2class.json')
    for _split in ('train', 'val', 'test')
)

# Only the training loader shuffles; the val/test loaders iterate in dataset order.
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_dl, test_dl = (
    DataLoader(_ds, batch_size=batch_size, shuffle=False)
    for _ds in (val_ds, test_ds)
)

In [9]:
# Restore previously saved classifier weights into both models.
# map_location='cpu' deserializes the tensors on the host, so the checkpoints
# load even when the device they were saved from is unavailable;
# load_state_dict then copies them into the existing parameters.
model1.load_state_dict(torch.load('weights/resnet_49', map_location='cpu'))
model2.load_state_dict(torch.load('weights/inception_49', map_location='cpu'))

In [10]:
# Device string used by the `.to(device)` calls below: the GPU when one is
# visible to PyTorch, otherwise the CPU so the notebook still runs.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [11]:
# The training loops pickle / JSON-dump their results into this directory.
# exist_ok=True makes the cell safe to re-run without a separate existence check.
os.makedirs('metrics', exist_ok=True)

def scores(preds_labels):
    """Summarise classification quality for a list of (prediction, label) pairs.

    Returns a dict with the plain accuracy, the macro-averaged F1, precision
    and recall, and the confusion matrix; every sklearn metric is called with
    the labels as first argument and the predictions as second.
    """
    y_pred = [pair[0] for pair in preds_labels]
    y_true = [pair[1] for pair in preds_labels]

    report = {}
    report['accuracy'] = np.mean([guess == truth for guess, truth in zip(y_pred, y_true)])
    report['fscore'] = f1_score(y_true, y_pred, average='macro')
    report['conf_mat'] = confusion_matrix(y_true, y_pred)
    report['precision'] = precision_score(y_true, y_pred, average='macro')
    report['recall'] = recall_score(y_true, y_pred, average='macro')
    return report

def evaluate(model, dl=None):
    """Run `model` over a loader and return the metrics dict built by `scores`.

    Args:
        model: classifier returning a (batch, classes) score tensor in eval mode.
        dl: iterable of (images, labels) batches; defaults to the notebook-level
            `val_dl` when omitted, which is what existing callers rely on.

    Returns:
        The dict produced by `scores` for all samples whose label is
        <= N_CLASSES (larger labels are dropped before scoring).

    The model is switched to eval mode for the pass and afterwards restored to
    whatever mode it was in on entry, instead of being forced into train mode.
    """
    loader = val_dl if dl is None else dl
    model = model.to(device)
    was_training = model.training

    preds = []
    golden = []
    model.eval()
    try:
        with torch.no_grad():
            for images, labels in tqdm(loader):
                logits = model(images.to(device))
                # One argmax / host transfer per batch rather than per sample.
                preds.extend(logits.argmax(dim=1).tolist())
                golden.extend(labels.tolist())
    finally:
        model.train(was_training)

    # NOTE(review): the datasets use N_CLASSES + 1 as their skip marker while this
    # filter keeps label == N_CLASSES; confirm whether `< N_CLASSES` was intended.
    preds_labels = [(pred, label) for pred, label in zip(preds, golden) if label <= N_CLASSES]
    return scores(preds_labels)

def _dump_training_history(filename, epoch, history):
    """Pickle `history` to metrics/<filename>_<epoch>.pickle."""
    with open(f'metrics/{filename}_{epoch}.pickle', 'wb') as f:
        pickle.dump(history, f)


def class_train_loop(model, filename, chkpt=None, N_EPOCHS=50):
    """Train `model` as a classifier on `train_dl` with cross-entropy and Adam.

    Args:
        model: network to train; Inception-style outputs are reduced to their
            `.logits` before the loss is computed.
        filename: base name for the checkpoint (`weights/<filename>`) and the
            pickled history (`metrics/<filename>_<epoch>.pickle`).
        chkpt: optional path of a state dict to load before training.
        N_EPOCHS: maximum number of epochs.

    Returns:
        dict with 'running_metrics' (one `evaluate` result per epoch) and
        'running_loss' (one value per optimisation step).

    After every epoch the model is evaluated; the weights are saved whenever the
    accuracy is at least as high as every accuracy seen so far. Training stops
    early after five consecutive epochs whose accuracy fell, or rose by less
    than 0.1% of its value, relative to the previous epoch.
    """
    bad_epoch_counter = 0
    running_loss = []
    running_metrics = []
    history = {'running_metrics': running_metrics, 'running_loss': running_loss}
    last_accuracy = 0

    if chkpt:
        # Deserialize on the host; the model is moved to `device` right below.
        model.load_state_dict(torch.load(chkpt, map_location='cpu'))

    # Use the notebook-wide `device` rather than a hard-coded .cuda().
    model = model.to(device)
    model.train()
    criterion = nn.CrossEntropyLoss()
    opt = torch.optim.Adam(model.parameters(), lr=3e-4)

    for epoch in range(N_EPOCHS):
        p_bar = tqdm(train_dl)
        for images, labels in p_bar:
            images = images.to(device)
            labels = labels.to(device)

            opt.zero_grad()
            preds = model(images)
            if isinstance(preds, models.inception.InceptionOutputs):
                # Only the main head contributes to the loss.
                preds = preds.logits
            loss = criterion(preds, labels)

            loss_value = loss.item()
            running_loss.append(loss_value)
            p_bar.set_description(f'current_loss: {loss_value}')
            loss.backward()
            opt.step()

        metrics = evaluate(model)
        # Make sure the next epoch runs in train mode whatever evaluate() left behind.
        model.train()
        running_metrics.append(metrics)
        accuracy = metrics['accuracy']

        if accuracy >= max(m['accuracy'] for m in running_metrics):
            torch.save(model.state_dict(), f'weights/{filename}')

        stalled = (
            accuracy < last_accuracy
            or (accuracy - last_accuracy) < 0.001 * accuracy
        )
        bad_epoch_counter = bad_epoch_counter + 1 if stalled else 0

        if bad_epoch_counter == 5:
            _dump_training_history(filename, epoch, history)
            return history

        last_accuracy = accuracy
        print(f"{epoch} epoch::{metrics['accuracy']}")

    # With N_EPOCHS <= 0 no epoch ran, so there is no epoch index to name a file after.
    if N_EPOCHS > 0:
        _dump_training_history(filename, N_EPOCHS - 1, history)
    return history

In [12]:
def evaluate2(model, dl=None):
    """Collect per-sample top-1 predictions together with their raw scores.

    Args:
        model: classifier returning a (batch, classes) score tensor in eval mode.
        dl: iterable of (images, labels) batches; defaults to the notebook-level
            `val_dl` when omitted.

    Returns:
        List of ((predicted_class, max_score), label) tuples for every sample
        whose label differs from N_CLASSES + 1. `max_score` is the largest
        value of the model output for that sample.

    The model's train/eval mode is restored on exit instead of being forced to
    train mode.
    """
    loader = val_dl if dl is None else dl
    model = model.to(device)
    was_training = model.training

    preds = []
    golden = []
    model.eval()
    try:
        with torch.no_grad():
            for images, labels in tqdm(loader):
                logits = model(images.to(device))
                # Batched max: one host transfer per batch instead of two per sample.
                top_scores, top_classes = logits.max(dim=1)
                preds.extend(zip(top_classes.tolist(), top_scores.tolist()))
                golden.extend(labels.tolist())
    finally:
        model.train(was_training)

    preds_labels = [(pred, label) for pred, label in zip(preds, golden) if (label != N_CLASSES + 1)]
    return preds_labels

In [17]:
# Train model1 as a classifier for at most 20 epochs, starting from the state dict
# in weights/resnet2 (this replaces whatever weights model1 held before the call).
# class_train_loop writes its checkpoint to weights/resnet3 and returns the
# loss/metric history, which this cell displays as its output.
class_train_loop(model1, 'resnet3', chkpt = 'weights/resnet2', N_EPOCHS=20)

In [18]:
# Same as the previous cell for model2: start from weights/inception2, train for at
# most 20 epochs, checkpoint to weights/inception3.
class_train_loop(model2, 'inception3', chkpt = 'weights/inception2', N_EPOCHS=20)

In [20]:
class CelebaEmbDataset(Dataset):
    """Triplet dataset: each item is an anchor image with a positive and a negative.

    Labels are derived exactly as in the classification dataset: raw identities
    from ``identity_CelebA.txt`` are re-indexed with train identities first,
    then identities first seen in ``val``, then in ``test``.

    Args:
        mode: name of the directory whose files make up this dataset.
        class_mapping: accepted for backward compatibility; not used.
        prev_class_mapping: optional path to a JSON file mapping image paths to
            labels; entries override computed labels (``train/`` in the key is
            replaced by ``<mode>/``) unless their value is ``N_CLASSES + 1``.

    Attributes:
        dd: label -> list of image paths of this split carrying that label.
    """

    def __init__(self, mode='train', class_mapping=None, prev_class_mapping=None):
        with open('identity_CelebA.txt') as f:
            label_mapping = dict(line.replace('\n', '').split(' ') for line in f)
        raw_identity = {name: int(value) for name, value in label_mapping.items()}

        # Sets keep the "already seen in an earlier split" checks O(1).
        train_ids = {raw_identity[name] for name in os.listdir('train')}
        val_ids = {raw_identity[name] for name in os.listdir('val')} - train_ids
        test_ids = {raw_identity[name] for name in os.listdir('test')} - train_ids - val_ids

        self.train_classes = sorted(train_ids)
        self.val_classes = sorted(val_ids)
        self.test_classes = sorted(test_ids)

        self.all_classes = self.train_classes + self.val_classes + self.test_classes

        self.class_mapping = {idx_old: idx_new for idx_new, idx_old in enumerate(self.all_classes)}

        self.images = [f'{mode}/{name}' for name in sorted(os.listdir(mode))]
        self.img2class = {
            f'{mode}/{name}': self.class_mapping.get(raw_id)
            for name, raw_id in raw_identity.items()
        }

        if prev_class_mapping:
            # Context manager so the JSON file handle is closed deterministically.
            with open(prev_class_mapping, 'r') as f:
                saved_mapping = json.load(f)
            for key, value in saved_mapping.items():
                if value != N_CLASSES + 1:
                    self.img2class[key.replace('train/', f'{mode}/')] = value

        self.dd = defaultdict(list)
        for img in self.images:
            self.dd[self.img2class[img]].append(img)

        # Labels that really own at least one image in this split. Negatives are
        # drawn from these, instead of from a hard-coded range(300) that both
        # ignored most classes and could hit a label with no images at all.
        self._labels = list(self.dd)

    def __len__(self):
        """Number of image files found in the split directory."""
        return len(self.images)

    def _negative_label(self, label):
        """Pick, uniformly at random, a label of this split different from `label`.

        Raises:
            ValueError: if the split holds no other label to sample from.
        """
        candidates = [other for other in self._labels if other != label]
        if not candidates:
            raise ValueError(f'no negative class available for label {label!r}')
        return candidates[np.random.randint(len(candidates))]

    @staticmethod
    def _load(path):
        """Open `path`, run it through `transform`, and close the file again."""
        with Image.open(path) as img:
            return transform(img)

    def __getitem__(self, idx):
        """Return a dict with 'anchor', 'positive' and 'negative' image tensors.

        The positive is a random image sharing the anchor's label (it may be the
        anchor itself); the negative is a random image of a different label.
        """
        anchor_path = self.images[idx]
        label = self.img2class[anchor_path]
        anchor = self._load(anchor_path)

        positive = self._load(np.random.choice(self.dd[label]))
        negative = self._load(np.random.choice(self.dd[self._negative_label(label)]))

        return {'anchor': anchor, 'positive': positive, 'negative': negative}

In [21]:
batch_size = 16

# Triplet datasets for the three splits, each merging the labels stored in img2class.json.
train_emb_ds, val_emb_ds, test_emb_ds = (
    CelebaEmbDataset(_split, prev_class_mapping='img2class.json')
    for _split in ('train', 'val', 'test')
)

# Only the training loader shuffles; the val/test loaders iterate in dataset order.
train_emb_dl = DataLoader(train_emb_ds, batch_size=batch_size, shuffle=True)
val_emb_dl, test_emb_dl = (
    DataLoader(_ds, batch_size=batch_size, shuffle=False)
    for _ds in (val_emb_ds, test_emb_ds)
)

In [22]:
# Replace both classification heads with a pass-through layer, so each model now
# returns the input that used to feed its final linear layer. nn.Identity ignores
# constructor arguments, so none are passed.
for _backbone in (model1, model2):
    _backbone.fc = nn.Identity()

In [23]:
from copy import deepcopy

def calculate_accuracy(val_embs, train_dd, threshold=0.5):
    """Nearest-prototype identification with on-the-fly enrolment of new faces.

    Args:
        val_embs: iterable of (embedding, true_label) pairs.
        train_dd: mapping label -> prototype embedding (1-D tensor). It is not
            modified; may be empty.
        threshold: a match is accepted only when the best cosine similarity is
            strictly greater than this value.

    Returns:
        (accuracy, pairs) where pairs is a list of (predicted, true) labels in
        input order and accuracy is the mean of predicted == true.

    For each embedding the most similar prototype is looked up. If the
    similarity clears the threshold, that prototype's label is predicted.
    Otherwise a face whose label is already enrolled counts as a miss
    (prediction -1), while a face with an unseen label is enrolled with its own
    embedding as prototype and counted as correct.
    """
    cs = nn.CosineSimilarity(dim=1)

    # Prototypes are only read and appended to, never modified in place, so
    # copying the tensors is unnecessary.
    labels = list(train_dd.keys())
    known = set(labels)
    embs = torch.vstack(list(train_dd.values())) if labels else None

    pairs = []
    for val_emb, val_face in tqdm(val_embs):
        if embs is None:
            # Empty gallery: nothing can match yet.
            score, label = float('-inf'), None
        else:
            cos_sim = cs(embs, val_emb)
            best = torch.argmax(cos_sim).item()
            score, label = cos_sim[best].item(), labels[best]

        if score > threshold:
            pred = label
        elif val_face in known:
            pred = -1
        else:
            # Enrol the new identity by appending one row, instead of
            # re-stacking the whole gallery.
            known.add(val_face)
            labels.append(val_face)
            row = val_emb.unsqueeze(0)
            embs = row if embs is None else torch.cat([embs, row], dim=0)
            pred = val_face

        pairs.append((pred, val_face))

    return np.mean([pred == true for pred, true in pairs]), pairs

def calculate_embs(model, dl):
    """Run `model` over every batch of `dl` without gradient tracking.

    Returns a list with one (output_row, label) tuple per sample, in loader
    order; output rows are detached tensors and labels are Python numbers.
    The model's train/eval mode is left as the caller set it.
    """
    outputs = []
    targets = []

    with torch.no_grad():
        for images, labels in tqdm(dl):
            images, labels = images.to(device), labels.to(device)
            batch_out = model(images)

            outputs.extend(row.detach() for row in batch_out)
            targets.extend(labels.tolist())

    return list(zip(outputs, targets))

def evaluate_emb(model, threshold=0.7, mode='val'):
    """Embed the train/val/test loaders and build one prototype per train label.

    `threshold` and `mode` are accepted but not used here. The model is moved
    to `device` and left in eval mode.

    Returns:
        (train_dd, val_embs, test_embs): train_dd maps each train label to the
        mean of its L2-normalised embeddings; val_embs and test_embs are the
        raw (embedding, label) lists from `calculate_embs`.
    """
    model = model.to(device)
    model.eval()

    def unit(vec):
        # Scale to unit Euclidean length.
        return vec / torch.sqrt(torch.sum(vec ** 2))

    train_embs, val_embs, test_embs = (
        calculate_embs(model, loader) for loader in (train_dl, val_dl, test_dl)
    )

    # Group the train embeddings by label ...
    train_dd = defaultdict(list)
    for emb, face in train_embs:
        train_dd[face].append(emb)

    # ... then collapse every group into a single prototype vector.
    for face in list(train_dd):
        normalised = torch.vstack([unit(emb) for emb in train_dd[face]])
        train_dd[face] = normalised.mean(dim=0)

    return train_dd, val_embs, test_embs

In [24]:
from pytorch_metric_learning import losses

def emb_train_loop(model, filename, chkpt=None, N_EPOCHS=15, lossname='triplet'):
    """Train `model` as an embedding network with a metric-learning loss.

    Args:
        model: backbone whose output is used as the embedding.
        filename: base name for `weights/<filename>_<lossname>` and the files
            written under `metrics/`.
        chkpt: optional path of a state dict to load before training.
        N_EPOCHS: number of epochs.
        lossname: 'triplet' (batches from `train_emb_dl`, nn.TripletMarginLoss)
            or 'arcface' / 'cosface' (batches from `train_dl`, losses from
            pytorch_metric_learning built for N_CLASSES classes).

    Returns:
        List with the loss value of every optimisation step.

    Raises:
        ValueError: if `lossname` is not one of the three supported names.

    After each epoch the weights are saved, validation accuracy is computed via
    `evaluate_emb` + `calculate_accuracy` (threshold 0.5) and written to
    `metrics/<filename>_<lossname>.json`, overwriting the previous epoch's file.
    """
    if lossname not in ('triplet', 'arcface', 'cosface'):
        # Previously an unknown name only surfaced later as a NameError on `criterion`.
        raise ValueError(f'unknown lossname: {lossname!r}')

    running_loss = []

    def get_preds(preds):
        # Inception returns a named tuple in train mode; keep only the main output.
        if isinstance(preds, models.inception.InceptionOutputs):
            return preds.logits
        return preds

    if chkpt:
        model.load_state_dict(torch.load(chkpt, map_location='cpu'))

    model = model.to(device)

    # Probe the embedding width in eval mode so the all-ones dummy batch does not
    # leak into BatchNorm running statistics, then switch to train mode.
    model.eval()
    with torch.no_grad():
        emb_dim = get_preds(model(torch.ones([2, 3, 299, 299], device=device))).shape[-1]
    model.train()

    if lossname == 'triplet':
        criterion = nn.TripletMarginLoss()
        loader = train_emb_dl
    elif lossname == 'arcface':
        criterion = losses.ArcFaceLoss(N_CLASSES, emb_dim).to(device)
        loader = train_dl
    else:
        criterion = losses.CosFaceLoss(N_CLASSES, emb_dim).to(device)
        loader = train_dl
    # NOTE(review): the arcface/cosface losses are built for N_CLASSES classes, so
    # labels coming from train_dl presumably must lie in [0, N_CLASSES) -- confirm.

    # ArcFace/CosFace hold their own trainable class weights; without handing them
    # to the optimizer they would stay at their random initialisation.
    # TripletMarginLoss has no parameters, so this adds nothing in that case.
    opt = torch.optim.Adam(list(model.parameters()) + list(criterion.parameters()))

    for _ in range(N_EPOCHS):
        p_bar = tqdm(loader)

        for batch in p_bar:
            opt.zero_grad()

            if lossname == 'triplet':
                embs = {
                    key: get_preds(model(tensor.to(device)))
                    for key, tensor in batch.items()
                }
                loss = criterion(embs['anchor'], embs['positive'], embs['negative'])
            else:
                images = batch[0].to(device)
                labels = batch[1].to(device)
                # Unwrap Inception outputs here too; the raw named tuple cannot be
                # fed to the loss.
                loss = criterion(get_preds(model(images)), labels)

            loss_value = loss.item()
            running_loss.append(loss_value)
            p_bar.set_description(f'current_loss: {loss_value}')
            loss.backward()
            opt.step()

        torch.save(model.state_dict(), f'weights/{filename}_{lossname}')

        train_dd, val_embs, _test_embs = evaluate_emb(model)
        acc, final_preds = calculate_accuracy(val_embs, train_dd, threshold=0.5)
        # evaluate_emb leaves the model in eval mode.
        model.train()
        with open(f'metrics/{filename}_{lossname}.json', 'w') as f:
            f.write(json.dumps({'acc': float(acc), 'final_preds': final_preds}))

    with open(f'metrics/{filename}_{lossname}.pickle', 'wb') as f:
        pickle.dump({'running_loss': running_loss}, f)

    return running_loss

In [None]:
# Not populated in this cell; kept because later cells may rely on the name.
stats = []

# Position in the model list below -> name used for log lines and output files.
# (The original literal used key 0 twice, leaving only 'inception' and no entry for 1.)
idx2model = {0: 'resnet', 1: 'inception'}

for loss in ['triplet', 'cosface', 'arcface']:
    # Rebuild both networks for every loss so each run starts from the same
    # classifier checkpoint rather than from the previous loss's result.
    model1 = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
    model2 = torch.hub.load('pytorch/vision:v0.10.0', 'inception_v3', pretrained=True)

    # The heads must match the checkpoints' shapes before the state dicts load.
    model1.fc = nn.Linear(in_features=512, out_features=N_CLASSES, bias=True)
    model2.fc = nn.Linear(in_features=2048, out_features=N_CLASSES, bias=True)

    # NOTE(review): the resnet starts from 'resnet3' (written by class_train_loop
    # above) while inception starts from 'inception_49' rather than 'inception3'
    # -- confirm this asymmetry is intended.
    model1.load_state_dict(torch.load('weights/resnet3', map_location='cpu'))
    model2.load_state_dict(torch.load('weights/inception_49', map_location='cpu'))

    # Drop the classification heads: the embedding loop trains on the features.
    model1.fc         = nn.Identity(512)
    model2.fc         = nn.Identity(2048)

    # Train BOTH backbones; the original list repeated model1, so inception was
    # never trained and the idx2model lookup failed for idx == 1.
    for idx, model in enumerate([model1, model2]):
        print(idx2model[idx], loss)
        emb_train_loop(model, f'{idx2model[idx]}_emb', lossname=loss, N_EPOCHS=50)