In [1]:
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor, ToPILImage, Resize, Compose, Normalize
from torchvision import transforms
from torch import nn
from torch.nn import functional as F
from torchvision.models import vgg11, vgg16, resnet18, vgg16_bn
from tensorboardX import SummaryWriter
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision.utils as vutils

In [2]:
import os
from collections import defaultdict
from tqdm import tqdm_notebook as tqdm
from PIL import Image as pilimage
import io
from PIL import Image

import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import random
from IPython.display import clear_output 
from collections import Counter

In [3]:
def get_grad_norm(parameters, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
    total_norm = total_norm ** (1. / norm_type)
    return total_norm

In [4]:
# for i in range(5):
#     !wget --user adiencedb --password adience http://www.cslab.openu.ac.il/download/adiencedb/AdienceBenchmarkOfUnfilteredFacesForGenderAndAgeClassification/fold_{i}_data.txt

In [5]:
# for i in range(5):
#     !mv ./fold_{i}_data.txt ../data/aligned_labels/

In [6]:
import pandas as pd

labels = pd.concat([
        pd.read_csv('../data/aligned_labels/fold_{}_data.txt'.format(i), delimiter='\t')
    for i in range(5)]).dropna().reset_index(drop=True)

In [7]:
from PIL import Image

def pil_loader(path):
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')

def rapply(trans):
    return transforms.RandomApply([trans], 0.3)

In [8]:
class AgeFolder(Dataset):
    def __init__(self, labels):
        self.classes = {cl: i for i, cl in enumerate([i[0] for i in Counter(labels['age']).most_common(8)])}
        self.labels = labels[labels['age'].isin(self.classes)].reset_index(drop=True)
        self.loader = pil_loader
        self.transform = Compose([Resize((224, 224)), ToTensor()])
        self.images = {index: self.transform(self.loader(self.get_image_path(index))) for index in tqdm(self.labels.index)}
        self.train_transform = Compose([ToPILImage(),
                                        rapply(transforms.RandomRotation(180)),
                                        rapply(transforms.RandomAffine(180)),
                                        rapply(transforms.RandomHorizontalFlip(0.5)),
                                        rapply(transforms.RandomVerticalFlip(0.5)),
                                        ToTensor()])
        self.train = True

    def __len__(self):
        return self.labels.shape[0]
        
    def get_image_path(self, index):
        row = self.labels.iloc[index]
        return '../data/aligned/{}/landmark_aligned_face.{}.{}'.format(row['user_id'], row['face_id'], row['original_image'])
    
    def __getitem__(self, index):
        image = self.images[index]
        image_class = self.classes[self.labels.iloc[index, 3]]
        if self.train:
            image = self.train_transform(image)
        return image, image_class

In [9]:
age_set = AgeFolder(labels)
dataset_size = len(age_set)
indices = list(range(dataset_size))
split = int(np.floor(0.2 * dataset_size))
np.random.seed(3)
np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

train_loader = torch.utils.data.DataLoader(age_set, batch_size=32, 
                                           sampler=train_sampler, num_workers=4)
validation_loader = torch.utils.data.DataLoader(age_set, batch_size=32,
                                                sampler=valid_sampler)


HBox(children=(IntProgress(value=0, max=17327), HTML(value='')))




In [10]:
age_model = vgg16_bn(pretrained=True)
age_model.classifier[6] = nn.Linear(4096, len(age_set.classes))
age_model.cuda()
optim = torch.optim.Adam(age_model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
writer = SummaryWriter('../logs/age_model/vgg16_bn_pretrained.v1')
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, mode='min', factor=0.5, patience=3)

In [11]:
class EarlyStopping:
    def __init__(self, name):
        self.lowest = 100
        self.lowest_model = '../data/models/'+name
        self.continuity = 0
        self.prev_value = 0
        
    def step(self, value, model):
        if value < self.lowest:
            torch.save(model, self.lowest_model)
            self.lowest = value
        if value > self.prev_value:
            self.continuity += 1
        else:
            self.continuity = 0
        self.prev_value = value
        if self.continuity >= 5:
            return True
        return False

In [12]:
i2class = {i: cl for i, cl in enumerate([i[0] for i in Counter(labels['age']).most_common(8)])}

In [13]:
def gen_plot(image_tensor, age, real_age):
    plt.clf()
    buf = io.BytesIO()
    plt.axis('off')
    plt.imshow(image_tensor.cpu().data.numpy().swapaxes(0, 2).swapaxes(0, 1))
    plt.text(5, 20, 'real = {}, pred = {}'.format(str(i2class[real_age]), str(i2class[age])), fontsize=14, color='red')
    plt.savefig(buf, format='jpg')
    buf.seek(0)
    plt.clf()
    return ToTensor()(Image.open(buf))

In [14]:
early_stopper = EarlyStopping('age_model.pth')

In [15]:
for epoch in tqdm(range(100)):
    losses = []
    age_model.train()
    train_loader.dataset.train = True
    for batch_ind, (image, target) in enumerate(tqdm(train_loader, leave=False)):
        optim.zero_grad()
        image, target = image.cuda(), target.cuda()
        pred = age_model(image)
        loss = criterion(pred, target)
        loss.backward()
        writer.add_scalar('train/grad_norm', get_grad_norm(age_model.parameters()), epoch*len(train_loader)+batch_ind)
        torch.nn.utils.clip_grad_norm_(age_model.parameters(), 1)
        optim.step()
        writer.add_scalar('train/batch', loss.item(), epoch*len(train_loader)+batch_ind)
        losses.append(loss.item())
    writer.add_scalar('train/epoch', np.mean(losses), epoch)
    age_model.eval()
    validation_loader.dataset.train = False
    losses = []
    for batch_ind, (image, target) in enumerate(tqdm(validation_loader, leave=False, desc='val')):
        image, target = image.cuda(), target.cuda()
        pred = age_model(image)
        loss = criterion(pred, target)
        writer.add_scalar('val/batch', loss.item(), epoch*len(validation_loader)+batch_ind)
        losses.append(loss.item())
    writer.add_scalar('val/epoch', np.mean(losses), epoch)
    scheduler.step(np.mean(losses))
    stopping = early_stopper.step(np.mean(losses), age_model)
    image, target = next(iter(validation_loader))        
    image = image[:5]
    pred = torch.max(age_model(image.cuda()).cpu(), dim=1)[1]
    real = target[:5]
    writer.add_image('age_person', vutils.make_grid(torch.stack([gen_plot(im, int(pr), int(rl)) for im, pr, rl in zip(image,
                                                                                                      pred,
                                                                                                      real)])), epoch)
    if stopping:
        print('Finished ad epoch {}'.format(epoch))
    

HBox(children=(IntProgress(value=0), HTML(value='')))

HBox(children=(IntProgress(value=0, max=434), HTML(value='')))

HBox(children=(IntProgress(value=0, description='val', max=109, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, max=434), HTML(value='')))

HBox(children=(IntProgress(value=0, description='val', max=109, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, max=434), HTML(value='')))

HBox(children=(IntProgress(value=0, description='val', max=109, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, max=434), HTML(value='')))

HBox(children=(IntProgress(value=0, description='val', max=109, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, max=434), HTML(value='')))

HBox(children=(IntProgress(value=0, description='val', max=109, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, max=434), HTML(value='')))

HBox(children=(IntProgress(value=0, description='val', max=109, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, max=434), HTML(value='')))

HBox(children=(IntProgress(value=0, description='val', max=109, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, max=434), HTML(value='')))

HBox(children=(IntProgress(value=0, description='val', max=109, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, max=434), HTML(value='')))

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



HBox(children=(IntProgress(value=0, description='val', max=109, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, max=434), HTML(value='')))

HBox(children=(IntProgress(value=0, description='val', max=109, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, max=434), HTML(value='')))

HBox(children=(IntProgress(value=0, description='val', max=109, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, max=434), HTML(value='')))

HBox(children=(IntProgress(value=0, description='val', max=109, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, max=434), HTML(value='')))

HBox(children=(IntProgress(value=0, description='val', max=109, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, max=434), HTML(value='')))

HBox(children=(IntProgress(value=0, description='val', max=109, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, max=434), HTML(value='')))

HBox(children=(IntProgress(value=0, description='val', max=109, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, max=434), HTML(value='')))

HBox(children=(IntProgress(value=0, description='val', max=109, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, max=434), HTML(value='')))

HBox(children=(IntProgress(value=0, description='val', max=109, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, max=434), HTML(value='')))

HBox(children=(IntProgress(value=0, description='val', max=109, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, max=434), HTML(value='')))

HBox(children=(IntProgress(value=0, description='val', max=109, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, max=434), HTML(value='')))

Process Process-104:
Process Process-102:
Process Process-103:
Process Process-101:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/mrartemev/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/mrartemev/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/mrartemev/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/mrartemev/anaconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/mrartemev/anaconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/mrartemev/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 96, in _worker_loop
    r = index_queue.get(timeou

KeyboardInterrupt: 

<Figure size 432x288 with 0 Axes>

In [13]:
age_model = torch.load('../data/models/age_model.pth')
criterion = nn.CrossEntropyLoss()


In [14]:
age_model.eval()
validation_loader.dataset.train = False
losses = []
for batch_ind, (image, target) in enumerate(tqdm(validation_loader, leave=False, desc='val')):
    image, target = image.cuda(), target.cuda()
    pred = age_model(image)
    loss = criterion(pred, target)
    losses.append(loss.item())

HBox(children=(IntProgress(value=0, description='val', max=109, style=ProgressStyle(description_width='initial…



In [15]:
np.mean(losses)

0.5548595805233771