In [1]:
import torch
import torch.nn as nn
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader

import os
import numpy as np
from PIL import Image
from tqdm import tqdm
from collections import defaultdict
from sklearn.metrics import f1_score, confusion_matrix, precision_score, recall_score
import pickle
import json

from transformers import CLIPProcessor, CLIPModel

In [2]:
# Output size of the classification head; also the class count handed to the
# ArcFace/CosFace losses.
N_CLASSES = 3000

# Seed torch and numpy before any random work (head initialisation, DataLoader
# shuffling, numpy-based triplet sampling) so runs are reproducible.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [3]:
from transformers import ViTFeatureExtractor, ViTForImageClassification
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
# Replace the pretrained head with a freshly initialised N_CLASSES-way linear
# layer; take its input width from the config rather than hard-coding 768.
model.classifier = nn.Linear(model.config.hidden_size, N_CLASSES)

In [4]:
def transform(img):
    """Turn an image into a normalised, square, 224x224 float tensor.

    The input is converted to a numpy array (it must be 3-D, i.e. have a
    channel axis), then to a CHW tensor and normalised per channel. It is
    center-cropped to a square whose side is the shorter spatial dimension,
    then resized to 224x224. The mean/std values match the commonly used
    ImageNet channel statistics.
    """
    arr = np.array(img)
    n_rows, n_cols, _ = arr.shape
    side = min(n_rows, n_cols)

    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        transforms.CenterCrop(side),
        transforms.Resize((224, 224)),
    ])
    return pipeline(arr)

In [5]:
class CelebaClassificationDataset(Dataset):
    """CelebA identity-classification dataset for one split directory.

    Images are listed from the directory named ``mode``. The directories
    ``train``, ``val`` and ``test`` must all exist, because the label
    re-indexing is built from all three. Raw identities are read from
    ``identity_CelebA.txt`` (one ``<filename> <identity>`` pair per line). They
    are mapped to contiguous indices: identities seen in ``train`` first
    (sorted), then those only in ``val``, then those only in ``test``.

    Args:
        mode: split directory to load images from.
        class_mapping: unused; kept so existing call sites stay valid.
        prev_class_mapping: optional path to a JSON dict of ``'train/<file>'``
            keys to labels. Every entry whose value is not ``N_CLASSES + 1``
            overrides the label of ``'<mode>/<file>'``.
    """

    def __init__(self, mode='train', class_mapping=None, prev_class_mapping=None):
        with open('identity_CelebA.txt') as f:
            label_mapping = dict(line.replace('\n', '').split(' ') for line in f if line.strip())
        raw_img2class = {key: int(value) for key, value in label_mapping.items()}

        # Sets make the "not already assigned to an earlier split" checks O(1).
        train_set = {raw_img2class[name] for name in os.listdir('train')}
        val_set = {raw_img2class[name] for name in os.listdir('val')} - train_set
        test_set = {raw_img2class[name] for name in os.listdir('test')} - train_set - val_set

        self.train_classes = sorted(train_set)
        self.val_classes = sorted(val_set)
        self.test_classes = sorted(test_set)
        self.all_classes = self.train_classes + self.val_classes + self.test_classes

        # raw identity -> contiguous index (train identities get the lowest indices)
        self.class_mapping = {idx_old: idx_new for idx_new, idx_old in enumerate(self.all_classes)}

        self.images = [f'{mode}/{name}' for name in sorted(os.listdir(mode))]
        # Identities outside every split map to None.
        self.img2class = {f'{mode}/{name}': self.class_mapping.get(value)
                          for name, value in raw_img2class.items()}

        if prev_class_mapping:
            with open(prev_class_mapping, 'r') as f:
                prev = json.load(f)
            for key, value in prev.items():
                # NOTE(review): N_CLASSES + 1 looks like an "ignore" sentinel (it is
                # also filtered out in evaluate()); confirm where such values come from.
                if value != N_CLASSES + 1:
                    self.img2class[key.replace('train/', f'{mode}/')] = value

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        path = self.images[idx]
        # The context manager closes the file handle. transform() calls np.array,
        # which forces the pixel data to load before the file is closed.
        with Image.open(path) as img:
            tensor = transform(img)
        return tensor, self.img2class[path]

In [6]:
# Persist the train-split label map that every split below is built from.
# Only build it when it is missing: on a fresh kernel `train_ds` does not exist
# yet, and an existing file may have been edited and must not be clobbered.
# Delete the file to force regeneration.
if not os.path.exists('img2class.json'):
    with open('img2class.json', 'w') as f:
        json.dump(CelebaClassificationDataset('train').img2class, f)

In [7]:
batch_size = 16

# All splits reuse the label map persisted from the train split.
train_ds, val_ds, test_ds = (
    CelebaClassificationDataset(split, prev_class_mapping='img2class.json')
    for split in ('train', 'val', 'test')
)

# Only the training loader is shuffled.
train_dl, val_dl, test_dl = (
    DataLoader(ds, batch_size=batch_size, shuffle=shuffle)
    for ds, shuffle in ((train_ds, True), (val_ds, False), (test_ds, False))
)

In [9]:
# Use the GPU when one is visible; fall back to CPU so the notebook still runs.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [10]:
# Output folders: pickled metrics, and model weights written by the torch.save
# calls in the training loops (weights/ was previously never created).
os.makedirs('metrics', exist_ok=True)
os.makedirs('weights', exist_ok=True)

def scores(preds_labels):
    """Compute classification metrics from ``(prediction, label)`` pairs.

    Args:
        preds_labels: sequence of ``(pred, label)`` pairs.

    Returns:
        dict with 'accuracy' (fraction of exact matches), macro-averaged
        'fscore', 'precision' and 'recall', and the sklearn 'conf_mat'.
    """
    y_pred = [pair[0] for pair in preds_labels]
    y_true = [pair[1] for pair in preds_labels]
    return {
        'accuracy': np.mean([p == t for p, t in zip(y_pred, y_true)]),
        'fscore': f1_score(y_true, y_pred, average='macro'),
        'conf_mat': confusion_matrix(y_true, y_pred),
        'precision': precision_score(y_true, y_pred, average='macro'),
        'recall': recall_score(y_true, y_pred, average='macro'),
    }

def evaluate(model, dl=None):
    """Classification metrics of ``model`` on a data loader.

    Args:
        model: classifier whose forward output exposes ``.logits``.
        dl: loader yielding ``(images, labels)`` batches; defaults to ``val_dl``.

    Returns:
        The ``scores`` dict over every sample whose label is not ``N_CLASSES + 1``.

    Side effects: moves ``model`` to ``device`` and leaves it in train mode.
    """
    if dl is None:
        dl = val_dl

    model = model.to(device)
    preds, golden = [], []
    model.eval()
    with torch.no_grad():
        for images, labels in tqdm(dl):
            logits = model(images.to(device)).logits
            # One argmax over the class axis for the whole batch.
            preds += logits.argmax(dim=-1).tolist()
            golden += labels.tolist()
    # The training loop calls this between epochs and keeps training afterwards.
    model.train()

    preds_labels = [(pred, label) for pred, label in zip(preds, golden) if label != N_CLASSES + 1]
    return scores(preds_labels)

def class_train_loop(model, filename, chkpt=None, N_EPOCHS=100, patience=5):
    """Fine-tune a classifier with cross-entropy and accuracy-based early stopping.

    After every epoch the model is scored with ``evaluate`` (on ``val_dl``). The
    weights with the best validation accuracy so far go to ``weights/<filename>``.
    An epoch is "bad" when accuracy drops, or improves by less than 0.1 %
    of the current accuracy. Training stops after ``patience`` consecutive
    bad epochs.

    Args:
        model: model whose forward output exposes ``.logits``.
        filename: name of the weights file and prefix of the metrics pickle.
        chkpt: optional state-dict path to resume from.
        N_EPOCHS: maximum number of epochs.
        patience: consecutive bad epochs that trigger early stopping.

    Returns:
        dict with 'running_metrics' (per-epoch ``evaluate`` output) and
        'running_loss' (per-batch loss). The same dict is pickled to
        ``metrics/<filename>_<last epoch>.pickle`` (``_-1`` when no epoch ran).
    """
    running_loss = []
    running_metrics = []
    best_accuracy = float('-inf')
    last_accuracy = 0
    bad_epoch_counter = 0

    if chkpt:
        model.load_state_dict(torch.load(chkpt, map_location=device))

    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    opt = torch.optim.Adam(model.parameters(), lr=2e-5)
    os.makedirs('weights', exist_ok=True)
    os.makedirs('metrics', exist_ok=True)

    epoch = -1  # keeps the pickle name defined even when N_EPOCHS == 0
    for epoch in range(N_EPOCHS):
        # The model may arrive in eval mode; evaluate() only restores train mode
        # after an epoch has already run, so set it explicitly every epoch.
        model.train()
        p_bar = tqdm(train_dl)
        for images, labels in p_bar:
            images = images.to(device)
            labels = labels.to(device)

            opt.zero_grad()
            loss = criterion(model(images).logits, labels)
            loss.backward()
            opt.step()

            loss_value = loss.item()
            running_loss.append(loss_value)
            p_bar.set_description(f'current_loss: {loss_value}')

        metrics = evaluate(model)
        running_metrics.append(metrics)
        accuracy = metrics['accuracy']

        if accuracy >= best_accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), f'weights/{filename}')

        if accuracy < last_accuracy or (accuracy - last_accuracy) < 0.001 * accuracy:
            bad_epoch_counter += 1
        else:
            bad_epoch_counter = 0

        if bad_epoch_counter >= patience:
            break

        last_accuracy = accuracy
        print(f"{epoch} epoch::{accuracy}")

    history = {'running_metrics': running_metrics, 'running_loss': running_loss}
    with open(f'metrics/{filename}_{epoch}.pickle', 'wb') as f:
        pickle.dump(history, f)
    return history

In [None]:
# Fine-tune the ViT classifier (at most 100 epochs; the loop stops early on plateau).
class_train_loop(model, filename='vit', N_EPOCHS=100)

In [None]:
class CelebaEmbDataset(CelebaClassificationDataset):
    """Triplet dataset ``{'anchor', 'positive', 'negative'}`` over one CelebA split.

    Split listing and label construction are inherited unchanged from
    ``CelebaClassificationDataset``. Each item pairs the indexed image (anchor)
    with a random other image of the same label (positive). It also adds a
    random image whose label differs and is present in the same split (negative).
    """

    def __init__(self, mode='train', class_mapping=None, prev_class_mapping=None):
        super().__init__(mode, class_mapping=class_mapping, prev_class_mapping=prev_class_mapping)

        # label -> image paths of this split carrying that label
        self.dd = defaultdict(list)
        for img in self.images:
            self.dd[self.img2class[img]].append(img)
        # Labels that actually have images here (the negative-label candidates).
        self._labels = list(self.dd)

    @staticmethod
    def _load(path):
        """Open ``path``, apply ``transform`` and close the file."""
        with Image.open(path) as img:
            return transform(img)

    def __getitem__(self, idx):
        anchor_path = self.images[idx]
        label = self.img2class[anchor_path]

        # Prefer a positive distinct from the anchor; fall back to the anchor when
        # it is the only image of its label in this split.
        same_label = self.dd[label]
        positives = [p for p in same_label if p != anchor_path] or same_label
        positive_path = np.random.choice(positives)

        # Draw the negative label only among labels that have images in this
        # split, so np.random.choice never receives an empty list.
        negative_labels = [c for c in self._labels if c != label]
        if not negative_labels:
            raise ValueError('at least two labels are needed in the split to sample a negative')
        negative_label = np.random.choice(negative_labels)
        negative_path = np.random.choice(self.dd[negative_label])

        return {'anchor': self._load(anchor_path),
                'positive': self._load(positive_path),
                'negative': self._load(negative_path)}

In [None]:
batch_size = 16

# Triplet datasets for every split, sharing the persisted train label map.
train_emb_ds, val_emb_ds, test_emb_ds = (
    CelebaEmbDataset(split, prev_class_mapping='img2class.json')
    for split in ('train', 'val', 'test')
)

# Only the training loader is shuffled.
train_emb_dl, val_emb_dl, test_emb_dl = (
    DataLoader(ds, batch_size=batch_size, shuffle=shuffle)
    for ds, shuffle in ((train_emb_ds, True), (val_emb_ds, False), (test_emb_ds, False))
)

In [None]:
# Replace the head with a pass-through so `.logits` returns the features that
# were fed to the classifier; these serve as embeddings below.
# (nn.Identity ignores constructor arguments, so no size is needed.)
model.classifier = nn.Identity()

In [None]:
from pytorch_metric_learning import losses

def emb_train_loop(model, filename, chkpt=None, N_EPOCHS=15, lossname='triplet'):
    """Fine-tune ``model`` as an embedding network.

    Args:
        model: network whose forward output is either a tensor of embeddings or
            an object exposing them as ``.logits``.
        filename: prefix for ``weights/<filename>_<lossname>`` (saved after
            every epoch) and ``metrics/<filename>_<lossname>.pickle``.
        chkpt: optional state-dict path to resume from.
        N_EPOCHS: number of epochs.
        lossname: 'triplet' (batches from ``train_emb_dl``), or 'arcface' /
            'cosface' (image/label batches from ``train_dl``).

    Returns:
        list of per-batch loss values.

    Raises:
        ValueError: if ``lossname`` is not one of the three supported losses.
    """
    if lossname not in ('triplet', 'arcface', 'cosface'):
        raise ValueError(f'unknown lossname: {lossname!r}')

    def embed(images):
        # Unwrap `.logits` when the model returns an output object; pass plain
        # tensors through unchanged.
        out = model(images)
        return getattr(out, 'logits', out)

    running_loss = []

    if chkpt:
        model.load_state_dict(torch.load(chkpt, map_location=device))
    model = model.to(device)

    with torch.no_grad():
        emb_dim = embed(torch.ones([1, 3, 224, 224], device=device)).shape[-1]

    if lossname == 'triplet':
        criterion = nn.TripletMarginLoss()
        loader = train_emb_dl
    elif lossname == 'arcface':
        criterion = losses.ArcFaceLoss(N_CLASSES, emb_dim).to(device)
        loader = train_dl
    else:
        criterion = losses.CosFaceLoss(N_CLASSES, emb_dim).to(device)
        loader = train_dl

    # ArcFace/CosFace own a learnable class-weight matrix; it must be optimised
    # together with the backbone, otherwise it stays at its random init.
    # (TripletMarginLoss has no parameters, so this adds nothing there.)
    opt = torch.optim.Adam(list(model.parameters()) + list(criterion.parameters()), lr=3e-4)

    os.makedirs('weights', exist_ok=True)
    os.makedirs('metrics', exist_ok=True)

    for _ in range(N_EPOCHS):
        model.train()
        p_bar = tqdm(loader)
        for batch in p_bar:
            opt.zero_grad()

            if lossname == 'triplet':
                embs = {key: embed(batch[key].to(device)) for key in ('anchor', 'positive', 'negative')}
                loss = criterion(embs['anchor'], embs['positive'], embs['negative'])
            else:
                images, labels = batch[0].to(device), batch[1].to(device)
                # Metric-learning losses take the embedding tensor, not the model output object.
                loss = criterion(embed(images), labels)

            loss.backward()
            opt.step()

            loss_value = loss.item()
            running_loss.append(loss_value)
            p_bar.set_description(f'current_loss: {loss_value}')

        torch.save(model.state_dict(), f'weights/{filename}_{lossname}')

    with open(f'metrics/{filename}_{lossname}.pickle', 'wb') as f:
        pickle.dump({'running_loss': running_loss}, f)

    return running_loss

In [None]:
def calculate_embs(model, dl):
    """Embed every image yielded by ``dl``.

    Args:
        model: network whose forward output exposes the embeddings as ``.logits``.
        dl: loader yielding ``(images, labels)`` batches.

    Returns:
        list of ``(embedding, label)`` pairs, one per sample. Each embedding is
        one row of the batch output (left on ``device``) and each label is a
        Python scalar.
    """
    embeddings, targets = [], []

    with torch.no_grad():
        for images, labels in tqdm(dl):
            batch_out = model(images.to(device)).logits
            embeddings.extend(row.detach() for row in batch_out)
            targets.extend(labels.tolist())

    return list(zip(embeddings, targets))

def calculate_accuracy(val_embs, train_dd, threshold):
    """Open-set identification accuracy of query embeddings against a gallery.

    Each query is matched to the gallery identity with the highest cosine
    similarity. A match above ``threshold`` is the prediction. Otherwise the
    query is "unknown". If its true identity is already in the gallery, the
    prediction is -1. If not, the query is enrolled as a new gallery entry and
    counted as correct; later queries can then match it.

    Args:
        val_embs: iterable of ``(embedding, label)`` pairs of 1-D tensors/labels.
        train_dd: mapping label -> gallery embedding. Not modified; enrolment
            happens on a private copy.
        threshold: cosine-similarity acceptance threshold.

    Returns:
        Fraction of queries whose prediction equals their label, or ``nan``
        when ``val_embs`` is empty.
    """
    cs = nn.CosineSimilarity(dim=0)
    gallery = dict(train_dd)
    hits = []

    for val_emb, val_face in tqdm(val_embs):
        if gallery:
            # max() returns the first of tied entries (gallery insertion order).
            best_face, best_score = max(
                ((face, cs(val_emb, emb)) for face, emb in gallery.items()),
                key=lambda pair: pair[1],
            )
        else:
            best_face, best_score = None, None

        if best_score is not None and best_score > threshold:
            pred = best_face
        elif val_face in gallery:
            pred = -1
        else:
            gallery[val_face] = val_emb
            pred = val_face

        hits.append(pred == val_face)

    return np.mean(hits) if hits else float('nan')

def evaluate_emb(model, threshold=0.7, mode='val'):
    """Open-set identification accuracy of an embedding model.

    The gallery holds one vector per train label: the mean of that label's
    L2-normalised train embeddings. Queries come from ``val_dl`` when
    ``mode == 'val'`` and from ``test_dl`` for any other ``mode``.

    Args:
        model: network whose forward output exposes embeddings as ``.logits``.
        threshold: cosine-similarity acceptance threshold for ``calculate_accuracy``.
        mode: 'val' or 'test'.

    Returns:
        The accuracy returned by ``calculate_accuracy``.

    Side effects: moves ``model`` to ``device`` and leaves it in eval mode.
    """
    model = model.to(device)
    model.eval()

    def norm(x):
        return x / torch.sqrt(torch.sum(x ** 2))

    # Embed only the query split that is actually scored.
    query_dl = val_dl if mode == 'val' else test_dl
    train_embs = calculate_embs(model, train_dl)
    query_embs = calculate_embs(model, query_dl)

    per_face = defaultdict(list)
    for emb, face in train_embs:
        per_face[face].append(norm(emb))
    train_dd = {face: torch.mean(torch.vstack(embs), dim=0) for face, embs in per_face.items()}

    return calculate_accuracy(query_embs, train_dd, threshold)

In [None]:
stats = []

idx2model = {0: 'vit'}
THRESHOLDS = [0.65, 0.7, 0.75, 0.8, 0.85]

for loss in ['triplet', 'cosface', 'arcface']:
    # Every loss starts from the same fine-tuned classifier checkpoint.
    model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
    model.classifier = nn.Linear(model.config.hidden_size, N_CLASSES)
    model.load_state_dict(torch.load('weights/vit', map_location=device))
    # Pass-through head: `.logits` now returns the features that fed the classifier.
    model.classifier = nn.Identity()

    for idx, model in enumerate([model]):
        emb_train_loop(model, f'{idx2model[idx]}_emb', lossname=loss, N_EPOCHS=15)

        # Choose the acceptance threshold on the validation split...
        thr_score = []
        for thr in THRESHOLDS:
            res = [thr, evaluate_emb(model, threshold=thr, mode='val')]
            thr_score.append(res)
            print(f"thr: {res[0]}, {res[1]}, mode==val")

        best_thr = max(thr_score, key=lambda x: x[1])[0]
        # ...and report it on the held-out test split.
        test_score = evaluate_emb(model, threshold=best_thr, mode='test')

        stats.append({'test_score': test_score,
                      'best_thr': best_thr,
                      'val_scores': thr_score,
                      'lossname': loss,
                      'model': idx2model[idx]})

        print(stats)

        with open('trans_stats.json', 'w') as f:
            f.write(json.dumps(stats))