In [1]:
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
# from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor, ToPILImage, Resize, Compose, Normalize
from torchvision import transforms
from torch import nn
from torch.nn import functional as F
from torchvision.models import vgg11, vgg16, resnet18, vgg16_bn, vgg11_bn, resnet50
from tensorboardX import SummaryWriter
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision.utils as vutils

In [2]:
import os
from collections import defaultdict
from tqdm import tqdm_notebook as tqdm
from PIL import Image as pilimage

import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import random
from IPython.display import clear_output 

In [3]:
from ImageFolderPreloader import ImageFolder

def rapply(trans):
    return transforms.RandomApply([trans], 0.3)

train_transform = Compose([ToPILImage(),
                            rapply(transforms.RandomRotation(180)),
                            rapply(transforms.RandomAffine(180)),
                            rapply(transforms.RandomHorizontalFlip(0.5)),
                            rapply(transforms.RandomVerticalFlip(0.5)),
                            ToTensor()])

In [4]:
import pickle

with open('../data/resnet50_ft_weight.pkl', 'rb') as f:
    weights = pickle.load(f, encoding='latin1')

In [5]:
identity_model = resnet50(num_classes=8631)

In [6]:
own_state = identity_model.state_dict()
for name, param in weights.items():
    if name in own_state:
        own_state[name].copy_(torch.from_numpy(param))


In [4]:
lfw_dset = ImageFolder('../data/lfw-deepfunneled/',
                       transform=Compose([Resize((224, 224)), ToTensor()]),
                       train_transform=train_transform)
dataset_size = len(lfw_dset)
indices = list(range(dataset_size))
split = int(np.floor(0.2 * dataset_size))
np.random.seed(3)
np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

train_loader = torch.utils.data.DataLoader(lfw_dset, batch_size=16, 
                                           sampler=train_sampler, num_workers=4)
validation_loader = torch.utils.data.DataLoader(lfw_dset, batch_size=16,
                                                sampler=valid_sampler)


HBox(children=(IntProgress(value=0, max=7606), HTML(value='')))




In [5]:
def get_grad_norm(parameters, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
    total_norm = total_norm ** (1. / norm_type)
    return total_norm

In [9]:
num_classes = len(lfw_dset.classes)
# identity_model = vgg11_bn(pretrained=True)
identity_model.fc = nn.Linear(2048, num_classes)
# identity_model.classifier[6] = nn.Linear(4096, num_classes)
identity_model.cuda()
criterion = nn.CrossEntropyLoss()
optim = torch.optim.Adam(identity_model.parameters(), lr=1e-5)
writer = SummaryWriter('../logs/identity_model/resnet50_pretrained.v1')
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, mode='min', factor=0.5, patience=3)

In [10]:
class EarlyStopping:
    def __init__(self, name):
        self.lowest = 100
        self.lowest_model = '../data/models/'+name
        self.continuity = 0
        self.prev_value = 0
        
    def step(self, value, model):
        if value < self.lowest:
            torch.save(model, self.lowest_model)
            self.lowest = value
        if value > self.prev_value:
            self.continuity += 1
        else:
            self.continuity = 0
        self.prev_value = value
        if self.continuity >= 5:
            return True
        return False

In [11]:
early_stopper = EarlyStopping('identity_model.pth')

In [12]:
for epoch in tqdm(range(100), desc='epoch'):
    losses = []
    identity_model.train()
    train_loader.dataset.train = True
    for batch_ind, (photo, target) in enumerate(tqdm(train_loader, leave=False, desc='train')):
        optim.zero_grad()
        photo, target = photo.cuda(), target.cuda()
        pred = identity_model(photo)
        loss = criterion(pred, target)
        loss.backward()
        writer.add_scalar('train/grad_norm', get_grad_norm(identity_model.parameters()), epoch*len(train_loader)+batch_ind)
        torch.nn.utils.clip_grad_norm_(identity_model.parameters(), 1)
        optim.step()
        writer.add_scalar('train/batch', loss.item(), epoch*len(train_loader)+batch_ind)
        losses.append(loss.item())
    writer.add_scalar('train/epoch', np.mean(losses), epoch)
    
    identity_model.eval()
    validation_loader.dataset.train = False

    identities = np.random.randint(0, len(lfw_dset.classes), size=5)
    identities_ans = [[], [], [], [], []]
    losses = []
    for batch_ind, (photo, target) in enumerate(tqdm(validation_loader, leave=False, desc='val')):
        photo, target = photo.cuda(), target.cuda()
        pred = identity_model(photo)
        
        for ind, ident in enumerate(identities):
            for i in photo[torch.nonzero(torch.max(pred, dim=1)[1] == int(ident))]:
                identities_ans[ind].append(i[0])

        loss = criterion(pred, target)
        writer.add_scalar('val/batch', loss.item(), epoch*len(validation_loader)+batch_ind)
        losses.append(loss.item())
    writer.add_scalar('val/epoch', np.mean(losses), epoch)
    early_stopper.step(np.mean(losses), identity_model)
    scheduler.step(np.mean(losses))
    for ind in range(5):
        try:
            writer.add_image('person_{}'.format(ind), vutils.make_grid(torch.stack(identities_ans[ind][:4]), nrow=4), epoch)
        except RuntimeError:
            pass
        

HBox(children=(IntProgress(value=0, description='epoch', style=ProgressStyle(description_width='initial')), HT…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=381, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=96, style=ProgressStyle(description_width='initial'…




In [14]:
torch.save(identity_model, '../data/identity_model.final')

In [6]:
identity_model = torch.load('../data/models/identity_model.pth')

In [7]:
class TripletDataset(Dataset):
    def __init__(self, lfw_dset):
        self.lfw_dset = lfw_dset
        self.images = defaultdict(list)
        for img, class_ind in lfw_dset.imgs:
            self.images[class_ind].append(img)
        self.anchors = {class_ind: self.images[class_ind] for class_ind in self.images.keys() if len(self.images[class_ind]) >= 3}
        
        
    def __len__(self):
        return len(self.anchors)
    
    def __getitem__(self, ind):
        far = random.sample(self.images.keys(), 1)[0]
        anchor, close = random.sample(self.anchors[ind], 2)
        far = random.sample(self.images[far], 1)[0]
        return self.lfw_dset.transform(self.lfw_dset.loader(anchor)), \
               self.lfw_dset.transform(self.lfw_dset.loader(close)), \
               self.lfw_dset.transform(self.lfw_dset.loader(far))
        

In [8]:
# for i in next(iter(TripletDataset(lfw_dset))):
#     plt.imshow(i.numpy().swapaxes(0, 2).swapaxes(0, 1))
#     plt.show()

In [10]:
identity_model.fc = nn.Sequential().cuda()

In [9]:
# identity_model.fc = nn.Linear(in_features=2048, out_features=512, bias=True).cuda()

In [11]:
triplet_dataset = TripletDataset(lfw_dset)
dataset_size = len(triplet_dataset)
indices = list(range(dataset_size))
split = int(np.floor(0.2 * dataset_size))
np.random.seed(3)
np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

train_loader = torch.utils.data.DataLoader(triplet_dataset, batch_size=4, 
                                           sampler=train_sampler)
validation_loader = torch.utils.data.DataLoader(triplet_dataset, batch_size=4,
                                                sampler=valid_sampler)


In [12]:
criterion = torch.nn.MarginRankingLoss(margin = 0.2) # mb change
optim = torch.optim.Adam(identity_model.parameters(), lr=1e-5)
writer = SummaryWriter('../logs/identity_model/resnet50_triplet.v4')

In [13]:
epoch_losses = []
identity_model.train()

for epoch in tqdm(range(100), desc='epoch'):
    losses = []
    for batch_ind, (anchor, close, far) in enumerate(tqdm(train_loader, leave=False, desc='train')):
        anchor, close, far = anchor.cuda(), close.cuda(), far.cuda()
        optim.zero_grad()
        anchor, close, far = identity_model(anchor), \
                             identity_model(close), \
                             identity_model(far)
        target = torch.FloatTensor(anchor.size(0)).fill_(1).cuda()
        loss = criterion(torch.pow(F.pairwise_distance(anchor, far), 2), \
                         torch.pow(F.pairwise_distance(anchor, close), 2), \
                         target)
        writer.add_scalar('train/batch', loss.item(), epoch*len(train_loader)+batch_ind)
        loss.backward()
        writer.add_scalar('train/grad_norm', get_grad_norm(identity_model.parameters()), epoch*len(train_loader)+batch_ind)
        torch.nn.utils.clip_grad_norm_(identity_model.parameters(), 1)
        optim.step()
        losses.append(loss.item())
    writer.add_scalar('train/epoch', np.mean(losses), epoch)

    losses = []
    for batch_ind, (anchor, close, far) in enumerate(tqdm(validation_loader, leave=False, desc='val')):
        anchor, close, far = anchor.cuda(), close.cuda(), far.cuda()
        anchor, close, far = identity_model(anchor), \
                             identity_model(close), \
                             identity_model(far)
        target = torch.FloatTensor(anchor.size(0)).fill_(1).cuda()
        loss = criterion(torch.pow(F.pairwise_distance(anchor, far), 2), \
                         torch.pow(F.pairwise_distance(anchor, close), 2), \
                         target)
        losses.append(loss.item())
        writer.add_scalar('val/batch', loss.item(), epoch*len(validation_loader)+batch_ind)
        losses.append(loss.item())
    writer.add_scalar('val/epoch', np.mean(losses), epoch)
    anchor, close, far = next(iter(validation_loader))
    anchor_pr, close_pr, far_pr = identity_model(anchor.cuda()).cpu().detach(), \
                                     identity_model(close.cuda()).cpu().detach(), \
                                     identity_model(far.cuda()).cpu().detach()
    dists_far = torch.pow(F.pairwise_distance(anchor_pr, far_pr), 2)
    dists_close = torch.pow(F.pairwise_distance(anchor_pr, close_pr), 2)
    for ind in range(4):
        if dists_far[ind] > dists_close[ind]:
            writer.add_image('person_{}'.format(ind), vutils.make_grid(torch.stack([anchor[ind],
                                                                                    close[ind],
                                                                                    far[ind],
                                                                                    anchor[ind],
                                                                                    close[ind],
                                                                                    far[ind]], dim=0), nrow=6), epoch)
        else:
            writer.add_image('person_{}'.format(ind), vutils.make_grid(torch.stack([anchor[ind],
                                                                                    close[ind],
                                                                                    far[ind],
                                                                                    anchor[ind],
                                                                                    far[ind],
                                                                                    close[ind]], dim=0), nrow=6), epoch)

            

HBox(children=(IntProgress(value=0, description='epoch', style=ProgressStyle(description_width='initial')), HT…

HBox(children=(IntProgress(value=0, description='train', max=181, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=45, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=181, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=45, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=181, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=45, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=181, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=45, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=181, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=45, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=181, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=45, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=181, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='val', max=45, style=ProgressStyle(description_width='initial'…

HBox(children=(IntProgress(value=0, description='train', max=181, style=ProgressStyle(description_width='initi…

KeyboardInterrupt: 

In [29]:
next(iter(validation_loader))

[tensor([[[[0.0196, 0.0196, 0.0196,  ..., 0.0314, 0.0157, 0.0039],
           [0.2000, 0.2000, 0.1882,  ..., 0.0392, 0.0157, 0.0039],
           [0.7412, 0.7294, 0.7059,  ..., 0.0314, 0.0157, 0.0039],
           ...,
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0039, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
 
          [[0.0039, 0.0039, 0.0039,  ..., 0.0196, 0.0039, 0.0000],
           [0.1843, 0.1843, 0.1725,  ..., 0.0275, 0.0039, 0.0000],
           [0.7255, 0.7137, 0.6902,  ..., 0.0196, 0.0039, 0.0000],
           ...,
           [0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0118, 0.0118],
           [0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0039, 0.0039],
           [0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0039, 0.0039]],
 
          [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0196, 0.0510],
           [0.1569, 0.1529, 0.1490,  ..., 0.0000, 0.0196, 0.0510],
           [0.6745, 0.66

In [16]:
torch.save(identity_model, '../data/identity_model.pth')