In [1]:
# set up path to facenet_pytorch_c
import sys
sys.path.insert(1, '/home/ubuntu/mtcnn')

In [2]:
# facenet_pytorch_c: avoid confusion with system default facenet_pytorch
from facenet_pytorch_c import MTCNN

from tqdm import tqdm
import numpy as np
import os

# pytorch
import torch
import torch.optim as optim
from torch import nn

# data handling
from torch.utils.data import DataLoader

# torchvision libs
from torchvision import datasets
from torchvision import transforms

# other custom scripts
import utils

In [3]:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Available device: " + str(device))

# training hyperparameters
learning_rate = 1e-3
epochs = 30
decay_step = [15]
decay_rate = 0.1
opt = 'Adam'    # either Adam or SGD
batch_size = 16


Available device: cuda:0


In [4]:
# data loading parameters
workers = 4
resize_shape = (224, 224)

In [5]:
# get data
x_train, age_train, fn_train, bbox_train, prob_train, land_train, x_valid, age_valid, fn_valid, bbox_valid, prob_valid, land_valid = utils.get_images(
    r'/home/ubuntu/UTKFace', resize_shape=resize_shape
)

  1%|          | 8/1000 [00:00<00:12, 76.35it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


  3%|▎         | 26/1000 [00:00<00:12, 79.55it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


  4%|▍         | 44/1000 [00:00<00:11, 80.97it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


  6%|▌         | 62/1000 [00:00<00:11, 82.15it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


  8%|▊         | 80/1000 [00:00<00:11, 82.15it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 10%|▉         | 98/1000 [00:01<00:10, 82.55it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 12%|█▏        | 116/1000 [00:01<00:10, 81.62it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 13%|█▎        | 134/1000 [00:01<00:10, 82.19it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 15%|█▌        | 152/1000 [00:01<00:10, 82.31it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 16%|█▌        | 161/1000 [00:01<00:10, 82.48it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 18%|█▊        | 179/1000 [00:02<00:09, 83.32it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 20%|█▉        | 197/1000 [00:02<00:09, 83.67it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 22%|██▏       | 215/1000 [00:02<00:09, 83.56it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 23%|██▎       | 233/1000 [00:02<00:09, 83.79it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 25%|██▌       | 251/1000 [00:03<00:08, 83.74it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 27%|██▋       | 269/1000 [00:03<00:08, 83.08it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 29%|██▊       | 287/1000 [00:03<00:08, 83.32it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 30%|███       | 305/1000 [00:03<00:08, 83.49it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 31%|███▏      | 314/1000 [00:03<00:08, 83.46it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 33%|███▎      | 332/1000 [00:04<00:08, 83.46it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 35%|███▌      | 350/1000 [00:04<00:07, 83.56it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 37%|███▋      | 368/1000 [00:04<00:07, 83.74it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(2, 5, 2)


 39%|███▊      | 386/1000 [00:04<00:07, 83.56it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 40%|████      | 404/1000 [00:04<00:07, 82.54it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 42%|████▏     | 422/1000 [00:05<00:06, 82.98it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 44%|████▍     | 440/1000 [00:05<00:06, 83.38it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 46%|████▌     | 458/1000 [00:05<00:06, 83.53it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 48%|████▊     | 476/1000 [00:05<00:06, 83.86it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 48%|████▊     | 485/1000 [00:05<00:06, 83.79it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 50%|█████     | 503/1000 [00:06<00:05, 83.86it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 52%|█████▏    | 521/1000 [00:06<00:05, 83.54it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 54%|█████▍    | 539/1000 [00:06<00:05, 83.47it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 56%|█████▌    | 557/1000 [00:06<00:05, 83.44it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 57%|█████▊    | 575/1000 [00:06<00:05, 83.66it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 59%|█████▉    | 593/1000 [00:07<00:04, 83.66it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 61%|██████    | 611/1000 [00:07<00:04, 84.04it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 63%|██████▎   | 629/1000 [00:07<00:04, 84.32it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 64%|██████▍   | 638/1000 [00:07<00:04, 84.30it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 66%|██████▌   | 656/1000 [00:07<00:04, 84.48it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 67%|██████▋   | 674/1000 [00:08<00:03, 84.40it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 69%|██████▉   | 692/1000 [00:08<00:03, 83.39it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 71%|███████   | 710/1000 [00:08<00:03, 83.12it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 73%|███████▎  | 728/1000 [00:08<00:03, 83.30it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 75%|███████▍  | 746/1000 [00:08<00:03, 83.36it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 76%|███████▋  | 764/1000 [00:09<00:02, 83.46it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 78%|███████▊  | 782/1000 [00:09<00:02, 83.26it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 80%|████████  | 800/1000 [00:09<00:02, 83.90it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 82%|████████▏ | 818/1000 [00:09<00:02, 84.31it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 84%|████████▎ | 836/1000 [00:10<00:01, 84.24it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 84%|████████▍ | 845/1000 [00:10<00:01, 83.66it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 86%|████████▋ | 863/1000 [00:10<00:01, 83.47it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 88%|████████▊ | 881/1000 [00:10<00:01, 82.56it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 90%|████████▉ | 899/1000 [00:10<00:01, 81.95it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 92%|█████████▏| 917/1000 [00:11<00:01, 81.26it/s]

(2, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 94%|█████████▎| 935/1000 [00:11<00:00, 82.41it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 95%|█████████▌| 953/1000 [00:11<00:00, 83.36it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 97%|█████████▋| 971/1000 [00:11<00:00, 82.97it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


 99%|█████████▉| 989/1000 [00:11<00:00, 83.83it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)


100%|██████████| 1000/1000 [00:12<00:00, 83.23it/s]

(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
(1, 5, 2)
Ignored images: 





In [6]:
# setup mtcnn
"""
mtcnn = MTCNN(
    image_size=160, margin=0, min_face_size=20,
    thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True, keep_all=True,
    device=device
)
"""
mtcnn = MTCNN(
    image_size=160, margin=0, min_face_size=20,
    thresholds=[0.5, 0.7, 0.7], factor=0.709, post_process=True, keep_all=True,
    device=device
)

In [7]:
# define data reader

# no need to convert to PIL, because get_images already does that
# also, disabling image normalization for now

# note: horizontal flip must be disabled, or else mtcnn bbox labels would be invalidated

transform_train = transforms.Compose([
    #transforms.ToPILImage(),
    #transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    #transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])



transform_valid = transforms.Compose([
    #transforms.ToPILImage(),
    transforms.ToTensor(),
    #transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

train_ds = utils.UTK_dataset(
    x_train, age_train, bbox_train, prob_train,
    land_train, trsfm=transform_train
)


train_loader = DataLoader(
    utils.UTK_dataset(
        x_train, age_train, bbox_train, prob_train,
        land_train, trsfm=transform_train
    ),
    batch_size=batch_size, num_workers=workers, shuffle=False # turn off shuffle for now
)


valid_loader = DataLoader(
    utils.UTK_dataset(
        x_valid, age_valid, bbox_valid, prob_valid,
        land_valid, trsfm=transform_valid
    ),
    batch_size=batch_size, num_workers=workers, shuffle=False
)


In [8]:
# This cell is for debugging the dataset / dataloader

import matplotlib.pyplot as plt
import torchvision
import PIL
"""
s = train_ds[0]
plt.imshow(torchvision.transforms.ToPILImage()(s[0]))
plt.show()

sys.path.insert(1, '/home/ubuntu/mtcnn/facenet_pytorch_c/models/utils')
import detect_face

img = detect_face.extract_face(torchvision.transforms.ToPILImage()(s[0]), s[2])
plt.imshow(torchvision.transforms.ToPILImage()(img))
plt.show()
""";

cnt = 0
for s in train_ds:
    cnt += 1
    print(len(s))
    if cnt > 10:
        break


5
5
5
5
5
5
5
5
5
5
5


In [9]:
# custom loss function, for training MTCNN

class LossFn:
    def __init__(self, cls_factor=1, box_factor=1, landmark_factor=1):
        # loss function
        self.cls_factor = cls_factor
        self.box_factor = box_factor
        self.land_factor = landmark_factor
        self.loss_cls = nn.BCELoss() # binary cross entropy
        self.loss_box = nn.MSELoss() # mean square error
        self.loss_landmark = nn.MSELoss()


    def cls_loss(self,gt_label,pred_label):
        pred_label = torch.squeeze(pred_label)
        gt_label = torch.squeeze(gt_label)
        # get the mask element which >= 0, only 0 and 1 can effect the detection loss
        mask = torch.ge(gt_label,0)
        valid_gt_label = torch.masked_select(gt_label,mask)
        valid_pred_label = torch.masked_select(pred_label,mask)
        return self.loss_cls(valid_pred_label,valid_gt_label)*self.cls_factor


    def box_loss(self,gt_label,gt_offset,pred_offset):
        pred_offset = torch.squeeze(pred_offset)
        gt_offset = torch.squeeze(gt_offset)
        gt_label = torch.squeeze(gt_label)

        #get the mask element which != 0
        unmask = torch.eq(gt_label,0)
        mask = torch.eq(unmask,0)
        #convert mask to dim index
        chose_index = torch.nonzero(mask.data)
        chose_index = torch.squeeze(chose_index)
        #only valid element can effect the loss
        valid_gt_offset = gt_offset[chose_index,:]
        valid_pred_offset = pred_offset[chose_index,:]
        return self.loss_box(valid_pred_offset,valid_gt_offset)*self.box_factor


    def landmark_loss(self,gt_label,gt_landmark,pred_landmark):
        pred_landmark = torch.squeeze(pred_landmark)
        gt_landmark = torch.squeeze(gt_landmark)
        gt_label = torch.squeeze(gt_label)
        mask = torch.eq(gt_label,-2)

        chose_index = torch.nonzero(mask.data)
        chose_index = torch.squeeze(chose_index)

        valid_gt_landmark = gt_landmark[chose_index, :]
        valid_pred_landmark = pred_landmark[chose_index, :]
        return self.loss_landmark(valid_pred_landmark,valid_gt_landmark)*self.land_factor

In [10]:
# train ONet

lossfn = LossFn()
mtcnn.onet.train(); # semicolon here, to suppress unnecessary pytorch output

optimizer = None

if opt == "Adam":
    print("Optimizer: Adam")
    optimizer = torch.optim.Adam(mtcnn.onet.parameters(), lr=learning_rate)
else:
    print("Error")

for epoch in range(1, 2):
    #
    for i, data in enumerate(train_loader):
        #inputs, ages, bboxs, landmarks = data
        _ = data
    #


Optimizer: Adam


RuntimeError: Caught RuntimeError in DataLoader worker process 1.
Original Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 79, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 79, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
    return torch.stack(batch, 0, out=out)
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 1 and 2 in dimension 1 at /opt/conda/conda-bld/pytorch_1579040055865/work/aten/src/TH/generic/THTensor.cpp:612
