In [None]:
# facenet_pytorch_c: avoid confusion with system default facenet_pytorch
#from facenet_pytorch_c import MTCNN

from tqdm import tqdm
import numpy as np
import os

# pytorch
import torch
import torch.optim as optim
from torch import nn

# data handling
from torch.utils.data import DataLoader

# torchvision libs
import torchvision
from torchvision import datasets
from torchvision import transforms
import PIL

import utils_pnet as utils


In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches

In [None]:

# Use the first CUDA GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Available device: " + str(device))

# reproducibility: seed the RNGs this notebook imports so that weight
# initialisation and DataLoader shuffling start from the same state on every
# Restart & Run All
# NOTE(review): utils_pnet may draw from other RNGs (e.g. Python's `random`) --
# confirm and seed those as well if so
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# training hyperparameters
learning_rate = 1e-3
epochs = 200
decay_step = [100, 150]    # epochs at which the LR is multiplied by decay_rate
decay_rate = 0.1
opt = 'Adam'    # either Adam or SGD
batch_size = 64

# data loading parameters
workers = 4    # passed to DataLoader as num_workers

In [None]:
# Load the training (*_t) and validation (*_v) splits. utils.get_images returns
# ten values; the 4th and 5th of each split are not used in this notebook.
# NOTE(review): the image/annotation paths are machine-specific absolute paths.
x_t, b_prob_t, b_box_t, _, _, x_v, b_prob_v, b_box_v, _, _ = utils.get_images(
    img_path='/home/ubuntu/db_proc/db/images',
    anno_path='/home/ubuntu/db_proc/db/annotations',
    valid_percent=0.1, resize_shape=(48,48),
    add_orig=False, gen_xtra_neg=True, gen_xtra_body=True, gen_xtra_face=False
)

print(len(x_t))

# Logging interval in batches: five batches short of len(x_t) / batch_size.
# Clamp to at least 1 -- for a small training set the raw value is <= 0, and the
# training loop uses print_freq as a modulus and as a divisor.
print_freq = max(1, len(x_t) // batch_size - 5)
print("print freq: {}".format(print_freq))

In [None]:

# Show every 20th sample among the first 300 training images together with its
# box annotation. The upper bound is capped by the dataset size so that a
# training set with fewer than 300 samples does not raise IndexError.
for i in range(0, min(300, len(x_t)), 20):
    plt.imshow(x_t[i])
    plt.show()
    print("body: {}".format(b_box_t[i]))

In [None]:
# Preprocessing for both splits is tensor conversion only. The normalisation
# step below is kept for reference but is currently disabled:
#transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
transform_train = transforms.Compose([transforms.ToTensor()])
transform_valid = transforms.Compose([transforms.ToTensor()])

# Wrap the arrays loaded earlier into datasets (images, probability labels, boxes).
train_ds = utils.CIVDS_bnet(x_t, b_prob_t, b_box_t, trsfm=transform_train)
valid_ds = utils.CIVDS_bnet(x_v, b_prob_v, b_box_v, trsfm=transform_valid)

# Both loaders share batch size and worker count; only the training loader shuffles.
loader_args = dict(batch_size=batch_size, num_workers=workers)
train_loader = DataLoader(train_ds, shuffle=True, **loader_args)
valid_loader = DataLoader(valid_ds, shuffle=False, **loader_args)

In [None]:
class BNet(nn.Module):
    """Small CNN with a two-way probability head and a four-value box head.

    Four conv + PReLU stages (the first three followed by max pooling) feed a
    256-unit dense layer. `dense5` expects 1152 flattened features, which is
    what the conv/pool stack produces for 3x48x48 inputs (128 x 3 x 3).
    """

    def __init__(self):
        super().__init__()

        # stage 1: 3 -> 32 channels
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3)
        self.prelu1 = nn.PReLU(32)
        self.pool1 = nn.MaxPool2d(3, 2, ceil_mode=True)

        # stage 2: 32 -> 64 channels
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.prelu2 = nn.PReLU(64)
        self.pool2 = nn.MaxPool2d(3, 2, ceil_mode=True)

        # stage 3: 64 -> 64 channels
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3)
        self.prelu3 = nn.PReLU(64)
        self.pool3 = nn.MaxPool2d(2, 2, ceil_mode=True)

        # stage 4: 64 -> 128 channels, no pooling
        self.conv4 = nn.Conv2d(64, 128, kernel_size=2)
        self.prelu4 = nn.PReLU(128)

        # shared dense trunk
        self.dense5 = nn.Linear(1152, 256)
        self.prelu5 = nn.PReLU(256)

        # head 1: body prob (softmax over two logits)
        self.dense6_1 = nn.Linear(256, 2)
        self.softmax6_1 = nn.Softmax(dim=1)

        # head 2: body bbox (four raw regression outputs)
        self.dense6_2 = nn.Linear(256, 4)

    def forward(self, x):
        """Run the network on a batch of images.

        Returns a tuple `(prob, box)`: `prob` is the softmax over the two
        outputs of `dense6_1`, `box` is the unactivated output of `dense6_2`.
        Both keep the batch dimension of `x` as their first axis.
        """
        x = self.pool1(self.prelu1(self.conv1(x)))
        x = self.pool2(self.prelu2(self.conv2(x)))
        x = self.pool3(self.prelu3(self.conv3(x)))
        x = self.prelu4(self.conv4(x))

        # flatten everything but the batch axis before the dense trunk
        flat = x.view(x.shape[0], -1)
        feat = self.prelu5(self.dense5(flat))

        prob = self.softmax6_1(self.dense6_1(feat))
        box = self.dense6_2(feat)
        return prob, box


In [None]:
# Build the network, move its parameters to the selected device and put it in
# training mode. The final expression displays the module summary in the cell output.
bnet = BNet().to(device)
bnet.train()

In [None]:

from tensorboardX import SummaryWriter

writer = SummaryWriter(log_dir="/home/ubuntu/tensorLog") # tensorboard writer

# BCELoss takes probabilities, which the network's softmax head already outputs;
# the box head is trained with a plain mean-squared error.
prob_lossfn = nn.BCELoss().to(device)
bbox_lossfn = nn.MSELoss().to(device)

if opt == "Adam":
    print("Optimizer: Adam")
    optimizer = torch.optim.Adam(bnet.parameters(), lr=learning_rate, amsgrad=True)
elif opt == "SGD":
    print("Optimizer: SGD")
    optimizer = torch.optim.SGD(bnet.parameters(), lr=learning_rate, momentum=0.9)
else:
    # fail here with a clear message instead of handing None to the scheduler
    raise ValueError("Unknown optimizer '{}': expected 'Adam' or 'SGD'".format(opt))

# multiply the learning rate by decay_rate at each epoch listed in decay_step
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=decay_step, gamma=decay_rate)

# print_freq is used as a modulus and a divisor below; never let it be <= 0
log_every = max(1, print_freq)

# make the cell safe to re-run after the evaluation cell has been executed
bnet.train()

try:
    for epoch in range(1, epochs+1):

        # Running sums restart every epoch. Otherwise the batches that follow
        # the last report of one epoch are folded into the next epoch's sum
        # while still being divided by log_every, inflating the logged loss.
        rl1, rl2 = 0.0, 0.0

        for batch_idx, data in enumerate(train_loader):

            im, b_prob, b_box = data

            im = im.to(device)
            b_prob = b_prob.float().to(device)
            b_box = b_box.float().to(device)

            o_b_prob, o_b_box = bnet(im)

            # Match the target layout explicitly. A blanket squeeze() would also
            # drop the batch axis when the final batch holds a single sample,
            # producing a shape mismatch against the target.
            o_b_prob = o_b_prob.float().reshape(b_prob.shape)
            o_b_box = o_b_box.float().reshape(b_box.shape)

            b_prob_l = prob_lossfn(o_b_prob, b_prob)
            b_box_l = bbox_lossfn(o_b_box, b_box)

            rl1 += b_prob_l.item()
            rl2 += b_box_l.item()

            # both heads are optimised jointly with equal weight
            all_loss = b_prob_l + b_box_l

            if batch_idx % log_every == log_every-1:

                print(
                    "ep: {}; bpl: {:.2f}; bbl: {:.2f};".format(
                        epoch, rl1/log_every, rl2/log_every
                    )
                )
                writer.add_scalar('bpl', rl1/log_every, epoch)
                writer.add_scalar('bbl', rl2/log_every, epoch)

                rl1, rl2 = 0.0, 0.0

            optimizer.zero_grad()
            all_loss.backward()
            optimizer.step()

        # the LR schedule advances once per epoch
        scheduler.step()
finally:
    # flush pending events and release the log file, even if training is interrupted
    writer.close()

print("finished training")

save_name = 'bnet.pt'
torch.save(bnet.state_dict(), save_name)
print('Saved model at {}'.format(save_name))

In [None]:
# Per-sample evaluation on the validation set: count how often the predicted
# class (index 0 vs. index 1 of the probability head) disagrees with the label,
# and preview every 20th sample with its prediction.
total = len(valid_ds)
b_err = 0

# Inference only: switch to eval mode and skip building autograd graphs.
# Note that this leaves bnet in eval mode after the cell has run.
bnet.eval()
with torch.no_grad():
    for idx in range(0, total):

        im = valid_ds[idx][0]    # fetch once; reused for the preview below

        obp, obb = bnet(im.unsqueeze(0).to(device))
        obp = obp.squeeze()

        aop = obp[0]                      # raw probability at index 0, for display
        obp = int(obp[0] > obp[1])        # predicted class

        tbp = int(b_prob_v[idx][0] > b_prob_v[idx][1])    # true class

        b_err += (obp != tbp)

        if idx % 20 == 0:
            plt.imshow(transforms.ToPILImage()(im))
            plt.show()
            print("tb: {}; pb: {:.1f}".format(int(tbp), aop.item()))
            print("body: " + str(np.rint(obb.squeeze().cpu().numpy())))

print("total: {}".format(total))
if total > 0:
    print("body accuracy: {:.1f}%".format(100*(total-b_err)/total))
else:
    # avoid dividing by zero when the validation split is empty
    print("body accuracy: n/a (empty validation set)")

# TODO: precision/recall/F1 reporting via utils.f1_score is not wired up yet;
# it needs per-sample truth/prediction lists, which this loop does not collect.

In [None]:
# Write the current weights of bnet to disk (same target file as the save at
# the end of the training cell, so this overwrites it).
save_name = 'bnet.pt'
weights = bnet.state_dict()
torch.save(weights, save_name)
print('Saved model at {}'.format(save_name))