In [1]:
# Set up the module search path so that the customized `facenet_pytorch_c`
# package (located under /home/ubuntu/mtcnn) can be imported by later cells.
# Inserting at index 1 leaves the first search-path entry untouched.
# NOTE(review): this is an absolute, machine-specific path — adjust it when
# running the notebook on another host.
import sys
sys.path.insert(1, '/home/ubuntu/mtcnn')

In [2]:
# facenet_pytorch_c: avoid confusion with system default facenet_pytorch
from facenet_pytorch_c import MTCNN

from tqdm import tqdm
import numpy as np
import os

# pytorch
import torch
import torch.optim as optim
from torch import nn

# data handling
from torch.utils.data import DataLoader

# torchvision libs
from torchvision import datasets
from torchvision import transforms

# other custom scripts
import utils

In [3]:

# Pick the compute device: the first CUDA GPU when available, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Available device: {device}")

# --- training hyperparameters ---
# optimizer selection and its step size
opt = 'Adam'    # either Adam or SGD
learning_rate = 1e-3

# learning-rate decay settings
decay_step = [15]
decay_rate = 0.1

# loop sizing
epochs = 30
batch_size = 16


Available device: cuda:0


In [4]:
# Settings for reading and batching the data.
resize_shape = (48, 48)  # forwarded to utils.get_images as `resize_shape`
workers = 4              # forwarded to the DataLoaders as `num_workers`

In [5]:
# Read the UTKFace data through the project helper. The helper hands back
# twelve collections: six for the training split followed by six for the
# validation split, in the same order for both.
(
    x_train, age_train, fn_train, bbox_train, prob_train, land_train,
    x_valid, age_valid, fn_valid, bbox_valid, prob_valid, land_valid,
) = utils.get_images(
    r'/home/ubuntu/UTKFace',
    resize_shape=resize_shape,
    label_mode=False,
)

100%|██████████| 500/500 [00:03<00:00, 158.40it/s]

Ignored images: 





In [6]:
# Set up MTCNN.
# The first entry of `thresholds` is lowered to 0.5 here; the previous
# configuration of this cell used thresholds=[0.6, 0.7, 0.7] with every other
# argument identical.
mtcnn = MTCNN(
    image_size=160, margin=0, min_face_size=20,
    thresholds=[0.5, 0.7, 0.7], factor=0.709, post_process=True, keep_all=True,
    device=device
)

In [7]:
# define data reader

# Only ToTensor is applied for now:
#   * ToPILImage is not needed, because get_images already returns PIL images.
#   * Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) is disabled for now.
#   * RandomHorizontalFlip must stay disabled, or else the mtcnn bbox labels
#     would be invalidated.
transform_train = transforms.Compose([
    transforms.ToTensor(),
])

transform_valid = transforms.Compose([
    transforms.ToTensor(),
])

# Build each dataset exactly once and hand that same object to its loader, so
# the dataset inspected for debugging (`train_ds`) is the one used for training.
train_ds = utils.UTK_dataset(
    x_train, age_train, bbox_train, prob_train,
    land_train, trsfm=transform_train
)

valid_ds = utils.UTK_dataset(
    x_valid, age_valid, bbox_valid, prob_valid,
    land_valid, trsfm=transform_valid
)

train_loader = DataLoader(
    train_ds,
    batch_size=batch_size, num_workers=workers, shuffle=False # turn off shuffle for now
)

valid_loader = DataLoader(
    valid_ds,
    batch_size=batch_size, num_workers=workers, shuffle=False
)


In [None]:
# Scratch cell for inspecting the dataset / dataloader by hand.

import matplotlib.pyplot as plt
import torchvision
import PIL
"""
s = train_ds[0]
plt.imshow(torchvision.transforms.ToPILImage()(s[0]))
plt.show()

sys.path.insert(1, '/home/ubuntu/mtcnn/facenet_pytorch_c/models/utils')
import detect_face

img = detect_face.extract_face(torchvision.transforms.ToPILImage()(s[0]), s[2])
plt.imshow(torchvision.transforms.ToPILImage()(img))
plt.show()
""";

# Walk over the leading samples of the training dataset and print how many
# fields each one carries; stop once the running count has gone past ten.
cnt = 0
for cnt, s in enumerate(train_ds, start=1):
    print(len(s))
    if cnt > 10:
        break


In [8]:
# custom loss function, for training MTCNN

class LossFn:
    """Multi-task loss for training MTCNN (classification, box and landmark terms).

    Each ``*_loss`` method selects the samples that are relevant for its task
    from the ground-truth label, evaluates its criterion on that subset, and
    scales the result by the matching factor.

    Sample selection, as implemented by the masks below:
      * ``cls_loss``      -- samples whose label is ``>= 0``
      * ``box_loss``      -- samples whose label is ``!= 0``
      * ``landmark_loss`` -- samples whose label is ``== -2``

    Robustness guarantees of this implementation:
      * a batch of size one works (the batch axis is never squeezed away);
      * when no sample is selected, a zero loss that is still attached to the
        prediction's autograd graph is returned instead of NaN;
      * targets are cast to the prediction dtype, so integer labels and
        float64 targets are accepted.

    Args:
        cls_factor: Weight applied to the classification loss.
        box_factor: Weight applied to the bounding-box regression loss.
        landmark_factor: Weight applied to the landmark regression loss.
    """

    def __init__(self, cls_factor=1, box_factor=1, landmark_factor=1):
        # per-task weights
        self.cls_factor = cls_factor
        self.box_factor = box_factor
        self.land_factor = landmark_factor
        # criteria
        self.loss_cls = nn.BCELoss()       # binary cross entropy
        self.loss_box = nn.MSELoss()       # mean square error
        self.loss_landmark = nn.MSELoss()  # mean square error

    @staticmethod
    def _zero_like(pred):
        """Return a scalar zero that stays connected to ``pred``'s graph."""
        return pred.sum() * 0

    @staticmethod
    def _select_rows(mask, gt, pred):
        """Return the rows of ``gt`` and ``pred`` where ``mask`` is True.

        Both tensors are viewed as ``(num_samples, -1)`` so that singleton
        trailing dimensions disappear while the batch axis is kept, even for
        a batch of one. ``gt`` is cast to the dtype of ``pred``.
        """
        num_samples = mask.numel()
        pred = pred.reshape(num_samples, -1)
        gt = gt.reshape(num_samples, -1).to(pred.dtype)
        return gt[mask], pred[mask]

    def cls_loss(self, gt_label, pred_label):
        """Binary cross entropy over samples whose label is ``>= 0``.

        Args:
            gt_label: Ground-truth labels, one per sample.
            pred_label: Predicted probabilities, one per sample (any shape
                that flattens to one value per sample).

        Returns:
            Scalar tensor: the BCE loss times ``cls_factor``, or zero when no
            sample has a label ``>= 0``.
        """
        pred_label = pred_label.reshape(-1)
        gt_label = gt_label.reshape(-1)
        # only labels >= 0 take part in the detection loss
        mask = gt_label >= 0
        if not bool(mask.any()):
            return self._zero_like(pred_label)
        valid_gt_label = gt_label[mask].to(pred_label.dtype)
        valid_pred_label = pred_label[mask]
        return self.loss_cls(valid_pred_label, valid_gt_label) * self.cls_factor

    def box_loss(self, gt_label, gt_offset, pred_offset):
        """Mean squared error over samples whose label is ``!= 0``.

        Args:
            gt_label: Ground-truth labels, one per sample.
            gt_offset: Ground-truth box values, one row per sample.
            pred_offset: Predicted box values, one row per sample.

        Returns:
            Scalar tensor: the MSE loss times ``box_factor``, or zero when
            every label equals 0.
        """
        # every sample with a non-zero label contributes
        mask = gt_label.reshape(-1) != 0
        if not bool(mask.any()):
            return self._zero_like(pred_offset)
        valid_gt_offset, valid_pred_offset = self._select_rows(mask, gt_offset, pred_offset)
        return self.loss_box(valid_pred_offset, valid_gt_offset) * self.box_factor

    def landmark_loss(self, gt_label, gt_landmark, pred_landmark):
        """Mean squared error over samples whose label is ``== -2``.

        Args:
            gt_label: Ground-truth labels, one per sample.
            gt_landmark: Ground-truth landmark values, one row per sample.
            pred_landmark: Predicted landmark values, one row per sample.

        Returns:
            Scalar tensor: the MSE loss times ``land_factor``, or zero when
            no label equals -2.
        """
        mask = gt_label.reshape(-1) == -2
        if not bool(mask.any()):
            return self._zero_like(pred_landmark)
        valid_gt_landmark, valid_pred_landmark = self._select_rows(mask, gt_landmark, pred_landmark)
        return self.loss_landmark(valid_pred_landmark, valid_gt_landmark) * self.land_factor

In [11]:
# train ONet

lossfn = LossFn()

# Put the whole detector, and ONet in particular, into training mode and make
# sure ONet lives on the selected device.
mtcnn.train()
mtcnn.onet.train()
mtcnn.onet.to(device)

# Build the optimizer named by `opt`. An unsupported name fails loudly instead
# of leaving `optimizer` unset for the code that will later rely on it.
if opt == "Adam":
    print("Optimizer: Adam")
    optimizer = torch.optim.Adam(mtcnn.onet.parameters(), lr=learning_rate)
elif opt == "SGD":
    print("Optimizer: SGD")
    # other SGD options (e.g. momentum) are left at their PyTorch defaults
    optimizer = torch.optim.SGD(mtcnn.onet.parameters(), lr=learning_rate)
else:
    raise ValueError(f"Unsupported optimizer {opt!r}; expected 'Adam' or 'SGD'")

# Work in progress: a single pass that only inspects the first batch.
for epoch in range(1, 2):

    for i, data in enumerate(train_loader):
        im, age, bbox, prob, landmarks = data

        # move the batch to the same device as the network
        im = im.to(device)
        age = age.to(device)
        bbox = bbox.to(device)
        prob = prob.to(device)
        landmarks = landmarks.to(device)

        o_bbox, o_landmarks, o_prob, o_age = mtcnn.onet(im)

        # For shape debugging, compare o_bbox / o_landmarks / o_prob against
        # bbox / landmarks / prob via their `.shape` attributes.

        # show predicted landmarks next to the ground-truth landmarks
        print(o_landmarks)
        print(landmarks)

        # stop after the first batch
        break


Optimizer: Adam
tensor([[0.2867, 0.6825, 0.5358, 0.3352, 0.6444, 0.2748, 0.2872, 0.4810, 0.7125,
         0.7282],
        [0.3181, 0.7159, 0.4587, 0.3045, 0.7156, 0.3160, 0.3092, 0.5278, 0.6626,
         0.6653],
        [0.3201, 0.7094, 0.4553, 0.3746, 0.7189, 0.2969, 0.2873, 0.5432, 0.6868,
         0.6748],
        [0.2774, 0.6876, 0.4685, 0.2811, 0.6524, 0.2988, 0.3079, 0.5609, 0.6659,
         0.6841],
        [0.2907, 0.7053, 0.4893, 0.2944, 0.6945, 0.2841, 0.2830, 0.5299, 0.6716,
         0.6795],
        [0.3101, 0.7216, 0.4147, 0.3273, 0.6708, 0.2974, 0.2767, 0.5101, 0.7096,
         0.7043],
        [0.3105, 0.7048, 0.4779, 0.3301, 0.7015, 0.2960, 0.2968, 0.5681, 0.6970,
         0.7013],
        [0.2904, 0.6717, 0.5538, 0.2956, 0.6334, 0.2890, 0.3056, 0.4999, 0.6983,
         0.7176],
        [0.2680, 0.6965, 0.5116, 0.3206, 0.6603, 0.2963, 0.2983, 0.4730, 0.7131,
         0.7261],
        [0.2535, 0.6678, 0.5304, 0.2692, 0.6210, 0.2981, 0.3129, 0.5240, 0.6925,
         0.7