In [1]:
# set up path to facenet_pytorch_c
import sys
sys.path.insert(1, '/home/ubuntu/mtcnn')

In [2]:
# facenet_pytorch_c: avoid confusion with system default facenet_pytorch
from facenet_pytorch_c import MTCNN

from tqdm import tqdm
import numpy as np
import os

# pytorch
import torch
import torch.optim as optim
from torch import nn

# data handling
from torch.utils.data import DataLoader

# torchvision libs
from torchvision import datasets
from torchvision import transforms

# other custom scripts
import utils

In [13]:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Available device: " + str(device))

# training hyperparameters
learning_rate = 1e-4
epochs = 100
decay_step = [15, 30, 45, 60]
decay_rate = 0.1
opt = 'SGD'    # either Adam or SGD
batch_size = 64

print_freq = int(23000/batch_size - 120)
print("print freq: {}".format(print_freq))


Available device: cuda:0
print freq: 239


In [4]:
# data loading parameters
workers = 4
resize_shape = (48, 48)
class_num = 5

In [5]:
# get data
x_train, age_train, fn_train, bbox_train, prob_train, land_train, x_valid, age_valid, fn_valid, bbox_valid, prob_valid, land_valid = utils.get_images(
    r'/home/ubuntu/UTKFace', age_thresh=(18, 35, 55, 75), resize_shape=resize_shape, label_mode=False
)

100%|██████████| 23708/23708 [02:29<00:00, 158.71it/s]

Ignored images: 





In [6]:
# setup mtcnn

mtcnn = MTCNN(
    image_size=160, margin=0, min_face_size=20,
    thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True, keep_all=True,
    device=device
)

In [14]:
# define data reader

# no need to convert to PIL, because get_images already does that
# also, disabling image normalization for now

# note: horizontal flip must be disabled, or else mtcnn bbox labels would be invalidated

transform_train = transforms.Compose([
    #transforms.ToPILImage(),
    #transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    #transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])



transform_valid = transforms.Compose([
    #transforms.ToPILImage(),
    transforms.ToTensor(),
    #transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

train_ds = utils.UTK_dataset(
    x_train, age_train, bbox_train, prob_train,
    land_train, trsfm=transform_train
)

valid_ds = utils.UTK_dataset(
    x_valid, age_valid, bbox_valid, prob_valid,
    land_valid, trsfm=transform_valid
)


train_loader = DataLoader(
    train_ds,
    batch_size=batch_size, num_workers=workers, shuffle=True
)


valid_loader = DataLoader(
    valid_ds,
    batch_size=batch_size, num_workers=workers, shuffle=False
)


In [15]:
# train ONet
import matplotlib.pyplot as plt
import torchvision
import PIL

from tensorboardX import SummaryWriter

writer = SummaryWriter(log_dir="/home/ubuntu/tensorLog") # tensorboard writer

mtcnn.train(); # semicolon here, to suppress unnecessary pytorch output
mtcnn.onet.train()
mtcnn.onet.to(device)

prob_lossfn = nn.BCELoss().to(device)
bbox_lossfn = nn.MSELoss().to(device)
landmarks_lossfn = nn.MSELoss().to(device)
age_lossfn = nn.CrossEntropyLoss().to(device)


optimizer = None

if opt == "Adam":
    print("Optimizer: Adam")
    optimizer = torch.optim.Adam(mtcnn.onet.parameters(), lr=learning_rate, amsgrad=True)
elif opt == "SGD":
    print("Optimizer: SGD")
    optimizer = torch.optim.SGD(mtcnn.onet.parameters(), lr=learning_rate, momentum=0.9)
else:
    print("Error")

scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=decay_step, gamma=decay_rate)

rl1, rl2, rl3, rl4 = 0, 0, 0, 0


for epoch in range(1, epochs+1):
    for batch_idx, data in enumerate(train_loader):
        im, age, bbox, prob, landmarks = data
        
        im = im.to(device)
        age = age.to(device)
        bbox = bbox.to(device)
        prob = prob.to(device)
        landmarks = landmarks.to(device)
        
        o_bbox, o_landmarks, o_prob, o_age = mtcnn.onet(im)
        
        prob_loss = prob_lossfn(o_prob, prob)
        bbox_loss = bbox_lossfn(o_bbox, bbox)
        landmarks_loss = landmarks_lossfn(o_landmarks, landmarks)
        age_loss = age_lossfn(o_age, age)
        
        rl1 += age_loss.item()
        rl2 += prob_loss.item()
        rl3 += landmarks_loss.item()
        rl4 += bbox_loss.item()
        
        
        all_loss = prob_loss*0.2 + bbox_loss*0.2 + landmarks_loss*0.2 + age_loss*1.0
        
        if batch_idx % print_freq == print_freq-1:
            print("epoch: {} agel: {:.4f} probl: {:.4f} landl: {:.4f} bboxl: {:.4f}".format(epoch, rl1/print_freq, rl2/print_freq,rl3/print_freq,rl4/print_freq,))
            writer.add_scalar('age_l', rl1/print_freq, epoch)
            writer.add_scalar('prob_l', rl2/print_freq, epoch)
            writer.add_scalar('land_l', rl3/print_freq, epoch)
            writer.add_scalar('bbox_l', rl4/print_freq, epoch)
            rl1, rl2, rl3, rl4 = 0, 0, 0, 0

        optimizer.zero_grad()
        all_loss.backward()
        optimizer.step()
        scheduler.step()
    
    if epoch % 5 == 0:
        save_name = 'epoch-{}.pth.tar'.format(epoch)
        torch.save({
            'epoch': epochs,
            'state_dict': mtcnn.onet.state_dict(),
            'opt_dict': optimizer.state_dict(),
        }, save_name)
        print('Saved model at {}'.format(save_name))

print("finished training")

Optimizer: SGD
epoch: 1 agel: 1.3190 probl: 0.0389 landl: 0.0074 bboxl: 0.0238
epoch: 2 agel: 1.6371 probl: 0.0444 landl: 0.0090 bboxl: 0.0295
epoch: 3 agel: 1.6393 probl: 0.0543 landl: 0.0091 bboxl: 0.0297
epoch: 4 agel: 1.6368 probl: 0.0513 landl: 0.0090 bboxl: 0.0294
epoch: 5 agel: 1.6343 probl: 0.0493 landl: 0.0090 bboxl: 0.0291
Saved model at epoch-5.pth.tar
epoch: 6 agel: 1.6406 probl: 0.0415 landl: 0.0090 bboxl: 0.0296
epoch: 7 agel: 1.6404 probl: 0.0590 landl: 0.0090 bboxl: 0.0295
epoch: 8 agel: 1.6384 probl: 0.0500 landl: 0.0090 bboxl: 0.0293
epoch: 9 agel: 1.6345 probl: 0.0485 landl: 0.0090 bboxl: 0.0297
epoch: 10 agel: 1.6394 probl: 0.0509 landl: 0.0090 bboxl: 0.0293
Saved model at epoch-10.pth.tar
epoch: 11 agel: 1.6361 probl: 0.0501 landl: 0.0089 bboxl: 0.0297
epoch: 12 agel: 1.6364 probl: 0.0424 landl: 0.0090 bboxl: 0.0294
epoch: 13 agel: 1.6337 probl: 0.0576 landl: 0.0090 bboxl: 0.0291
epoch: 14 agel: 1.6430 probl: 0.0471 landl: 0.0090 bboxl: 0.0298
epoch: 15 agel: 1.632

KeyboardInterrupt: 

In [10]:
# test mtcnn performance

import matplotlib.pyplot as plt
import torchvision.transforms
import PIL

age_thresh = (0, 6, 18, 200)
age_str = ['{}~{}'.format(a, b) for (a, b) in zip(age_thresh[:-1], age_thresh[1:])]

total = len(valid_ds)
err = 0
truth = []
pred = []

bins = [0] * class_num

for idx in range(0, total):
    
    #plt.imshow(transforms.ToPILImage()(train_ds[idx][0]))
    #plt.show()
    
    _, _, _, o_age = mtcnn.onet(valid_ds[idx][0].unsqueeze(0).to(device))
    
    bins[age_valid[idx]] += 1
    
    if o_age.max(1)[1].item() != age_valid[idx]:
        err += 1
    
    pred.append(o_age.max(1)[1].item())
    truth.append(age_valid[idx])
    #print("pred: {}, gt: {}".format(o_age.max(1)[1].item(), age_train[idx]))

print("accuracy: {:.1f}%".format(100*(total-err)/total))
p, r, f1 = utils.f1_score(truth, pred, 0)
print("precision: {:.2f}, recall: {:.2f}, f1: {:.2f}".format(p, r, f1))
print("valid_ds length: {}".format(len(valid_ds)))
print("age dist: ", end='')
for k in bins:
    print(str(k)+" ", end='')
print()



accuracy: 48.6%
precision: 0.46, recall: 0.33, f1: 0.39
valid_ds length: 4741
age dist: 880 2166 1030 493 172 
