# Train

In [1]:
import os
import os.path as osp

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

from model import FaceMobileNet,myFaceMobileNet
from model.metric import ArcFace, CosFace
from model.loss import FocalLoss
from dataset import load_data
from config import Config as conf

import numpy as np
from PIL import Image
from ptflops import get_model_complexity_info

In [2]:
os.chdir("D:\Course\Graduate1\Data_Science\hw\hw5")

In [3]:
# Data Setup
dataloader, class_num = load_data(conf, training=True)
embedding_size = conf.embedding_size
device = conf.device

10574Total train images: 18325


In [4]:
# Network Setup
net = myFaceMobileNet(embedding_size).to(device)

if conf.metric == 'arcface':
    metric = ArcFace(embedding_size, class_num).to(device)
else:
    metric = CosFace(embedding_size, class_num).to(device)

net = nn.DataParallel(net)
metric = nn.DataParallel(metric)

In [5]:
# Training Setup
if conf.loss == 'focal_loss':
    criterion = FocalLoss(gamma=2)
else:
    criterion = nn.CrossEntropyLoss()

if conf.optimizer == 'sgd':
    optimizer = optim.SGD([{'params': net.parameters()}, {'params': metric.parameters()}], 
                            lr=conf.lr, weight_decay=conf.weight_decay)
else:
    optimizer = optim.Adam([{'params': net.parameters()}, {'params': metric.parameters()}],
                            lr=conf.lr, weight_decay=conf.weight_decay)

scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=conf.lr_step, gamma=0.1)

In [6]:
# Checkpoints Setup
os.makedirs(conf.checkpoints, exist_ok=True)

In [14]:
# Start training
net.train()
losses = []
#Test at the same time
images = unique_image(conf.test_list)
images = [osp.join(conf.test_root, img) for img in images]
groups = group_image(images, conf.test_batch_size)
feature_dict = dict()

for e in range(conf.epoch):
    #if e<=20:
    #    continue
    for data, labels in tqdm(dataloader, desc=f"Epoch {e}/{conf.epoch}",
                             ascii=True, total=len(dataloader)):
        data = data.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        embeddings = net(data)
        thetas = metric(embeddings, labels)
        loss = criterion(thetas, labels)
        loss.backward()
        optimizer.step()

    print(f"Epoch {e}/{conf.epoch}, Loss: {loss}")
    losses.append(loss)
    # Test
    net.eval()
    feature_dict = dict()
    for group in groups:
        d = featurize(group, conf.test_transform, net, conf.device,low_light = False)
        feature_dict.update(d) 
    accuracy, threshold = compute_accuracy(feature_dict, conf.test_list, conf.test_root) 

    print(
        f"Test Model: myFMN_{e}_sample{int(conf.train_sample_rate*100)}%.pth\n",
        f"Accuracy: {accuracy:.3f}\n"
        f"Threshold: {threshold:.3f}\n"
    )
    net.train()
    backbone_path = osp.join(conf.checkpoints, f"myFMN_{e}_sample{int(conf.train_sample_rate*100)}%.pth")
    torch.save(net.state_dict(), backbone_path)
    scheduler.step()

Epoch 0/30: 100%|##########| 287/287 [00:38<00:00,  7.53it/s]


Epoch 0/30, Loss: 0.09242450445890427
Test Model: myFMN_0_sample10%.pth
 Accuracy: 0.558
Threshold: 0.478



Epoch 1/30: 100%|##########| 287/287 [00:32<00:00,  8.89it/s]


Epoch 1/30, Loss: 1.6215300559997559
Test Model: myFMN_1_sample10%.pth
 Accuracy: 0.530
Threshold: 0.971



Epoch 2/30: 100%|##########| 287/287 [00:32<00:00,  8.90it/s]


Epoch 2/30, Loss: 0.010787849314510822
Test Model: myFMN_2_sample10%.pth
 Accuracy: 0.552
Threshold: 0.935



Epoch 3/30: 100%|##########| 287/287 [00:32<00:00,  8.87it/s]


Epoch 3/30, Loss: 7.4444990158081055
Test Model: myFMN_3_sample10%.pth
 Accuracy: 0.546
Threshold: 0.746



Epoch 4/30: 100%|##########| 287/287 [00:32<00:00,  8.87it/s]


Epoch 4/30, Loss: 5.121672438690439e-05
Test Model: myFMN_4_sample10%.pth
 Accuracy: 0.552
Threshold: 0.921



Epoch 5/30: 100%|##########| 287/287 [00:32<00:00,  8.88it/s]


Epoch 5/30, Loss: 4.027842044830322
Test Model: myFMN_5_sample10%.pth
 Accuracy: 0.539
Threshold: 0.992



Epoch 6/30: 100%|##########| 287/287 [00:32<00:00,  8.86it/s]


Epoch 6/30, Loss: 1.250727927981643e-05
Test Model: myFMN_6_sample10%.pth
 Accuracy: 0.532
Threshold: 0.957



Epoch 7/30: 100%|##########| 287/287 [00:32<00:00,  8.87it/s]


Epoch 7/30, Loss: 2.4241514205932617
Test Model: myFMN_7_sample10%.pth
 Accuracy: 0.539
Threshold: 0.918



Epoch 8/30: 100%|##########| 287/287 [00:32<00:00,  8.82it/s]


Epoch 8/30, Loss: 0.5672804117202759
Test Model: myFMN_8_sample10%.pth
 Accuracy: 0.547
Threshold: 0.808



Epoch 9/30: 100%|##########| 287/287 [00:32<00:00,  8.82it/s]


Epoch 9/30, Loss: 4.12794303894043
Test Model: myFMN_9_sample10%.pth
 Accuracy: 0.534
Threshold: 0.958



Epoch 10/30: 100%|##########| 287/287 [00:32<00:00,  8.90it/s]


Epoch 10/30, Loss: 0.4781588315963745
Test Model: myFMN_10_sample10%.pth
 Accuracy: 0.537
Threshold: 0.773



Epoch 11/30: 100%|##########| 287/287 [00:32<00:00,  8.83it/s]


Epoch 11/30, Loss: 3.604029417037964
Test Model: myFMN_11_sample10%.pth
 Accuracy: 0.539
Threshold: 0.803



Epoch 12/30: 100%|##########| 287/287 [00:32<00:00,  8.84it/s]


Epoch 12/30, Loss: 0.17157980799674988
Test Model: myFMN_12_sample10%.pth
 Accuracy: 0.541
Threshold: 0.838



Epoch 13/30: 100%|##########| 287/287 [00:32<00:00,  8.80it/s]


Epoch 13/30, Loss: 7.0279541015625
Test Model: myFMN_13_sample10%.pth
 Accuracy: 0.536
Threshold: 0.810



Epoch 14/30: 100%|##########| 287/287 [00:32<00:00,  8.87it/s]


Epoch 14/30, Loss: 0.20764197409152985
Test Model: myFMN_14_sample10%.pth
 Accuracy: 0.531
Threshold: 0.760



Epoch 15/30: 100%|##########| 287/287 [00:32<00:00,  8.81it/s]


Epoch 15/30, Loss: 0.2903079688549042
Test Model: myFMN_15_sample10%.pth
 Accuracy: 0.535
Threshold: 0.818



Epoch 16/30: 100%|##########| 287/287 [00:32<00:00,  8.86it/s]


Epoch 16/30, Loss: 0.13993732631206512
Test Model: myFMN_16_sample10%.pth
 Accuracy: 0.540
Threshold: 0.869



Epoch 17/30: 100%|##########| 287/287 [00:32<00:00,  8.80it/s]


Epoch 17/30, Loss: 0.4362316429615021
Test Model: myFMN_17_sample10%.pth
 Accuracy: 0.538
Threshold: 0.847



Epoch 18/30: 100%|##########| 287/287 [00:32<00:00,  8.82it/s]


Epoch 18/30, Loss: 1.659886765992269e-05
Test Model: myFMN_18_sample10%.pth
 Accuracy: 0.537
Threshold: 0.891



Epoch 19/30: 100%|##########| 287/287 [00:32<00:00,  8.80it/s]


Epoch 19/30, Loss: 0.19764065742492676
Test Model: myFMN_19_sample10%.pth
 Accuracy: 0.548
Threshold: 0.847



Epoch 20/30: 100%|##########| 287/287 [00:32<00:00,  8.83it/s]


Epoch 20/30, Loss: 3.047953032364603e-05
Test Model: myFMN_20_sample10%.pth
 Accuracy: 0.541
Threshold: 0.886



Epoch 21/30: 100%|##########| 287/287 [00:32<00:00,  8.80it/s]


Epoch 21/30, Loss: 0.15045003592967987
Test Model: myFMN_21_sample10%.pth
 Accuracy: 0.545
Threshold: 0.860



Epoch 22/30: 100%|##########| 287/287 [00:32<00:00,  8.84it/s]


Epoch 22/30, Loss: 0.00010764157923404127
Test Model: myFMN_22_sample10%.pth
 Accuracy: 0.547
Threshold: 0.830



Epoch 23/30: 100%|##########| 287/287 [00:32<00:00,  8.86it/s]


Epoch 23/30, Loss: 0.13783210515975952
Test Model: myFMN_23_sample10%.pth
 Accuracy: 0.546
Threshold: 0.851



Epoch 24/30: 100%|##########| 287/287 [00:32<00:00,  8.78it/s]


Epoch 24/30, Loss: 0.12723605334758759
Test Model: myFMN_24_sample10%.pth
 Accuracy: 0.541
Threshold: 0.845



Epoch 25/30: 100%|##########| 287/287 [00:32<00:00,  8.80it/s]


Epoch 25/30, Loss: 0.0011353755835443735
Test Model: myFMN_25_sample10%.pth
 Accuracy: 0.539
Threshold: 0.873



Epoch 26/30: 100%|##########| 287/287 [00:32<00:00,  8.83it/s]


Epoch 26/30, Loss: 1.2327559488767292e-05
Test Model: myFMN_26_sample10%.pth
 Accuracy: 0.542
Threshold: 0.873



Epoch 27/30: 100%|##########| 287/287 [00:32<00:00,  8.79it/s]


Epoch 27/30, Loss: 0.20878396928310394
Test Model: myFMN_27_sample10%.pth
 Accuracy: 0.538
Threshold: 0.915



Epoch 28/30: 100%|##########| 287/287 [00:32<00:00,  8.84it/s]


Epoch 28/30, Loss: 0.17809227108955383
Test Model: myFMN_28_sample10%.pth
 Accuracy: 0.543
Threshold: 0.784



Epoch 29/30: 100%|##########| 287/287 [00:32<00:00,  8.79it/s]


Epoch 29/30, Loss: 0.11812622100114822
Test Model: myFMN_29_sample10%.pth
 Accuracy: 0.541
Threshold: 0.861



# Test

In [7]:
def unique_image(pair_list) -> set:
    """Return unique image path in pair_list.txt"""
    with open(pair_list, 'r') as fd:
        pairs = fd.readlines()
    unique = set()
    for pair in pairs:
        id1, id2, _ = pair.split()
        unique.add(id1)
        unique.add(id2)
    return unique

In [8]:
def group_image(images: set, batch) -> list:
    """Group image paths by batch size"""
    images = list(images)
    size = len(images)
    res = []
    for i in range(0, size, batch):
        end = min(batch + i, size)
        res.append(images[i : end])
    return res

In [9]:
def _preprocess(images: list, transform, low_light = False) -> torch.Tensor:
    res = []
    for img in images:
        im = Image.open(img)
        if low_light:
            im = (np.sqrt (im)*2).astype(np.uint8)
            im = Image.fromarray(im)
        im = transform(im)
        res.append(im)
    data = torch.cat(res, dim=0)  # shape: (batch, 128, 128)
    data = data[:, None, :, :]    # shape: (batch, 1, 128, 128)
    return data

In [10]:
def featurize(images: list, transform, net, device,low_light = False) -> dict:
    """featurize each image and save into a dictionary
    Args:
        images: image paths
        transform: test transform
        net: pretrained model
        device: cpu or cuda
    Returns:
        Dict (key: imagePath, value: feature)
    """
    data = _preprocess(images, transform,low_light = low_light)
    data = data.to(device)
    net = net.to(device)
    with torch.no_grad():
        features = net(data) 
    res = {img: feature for (img, feature) in zip(images, features)}
    return res

In [11]:
def cosin_metric(x1, x2):
    return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))

In [12]:
def threshold_search(y_score, y_true):
    y_score = np.asarray(y_score)
    y_true = np.asarray(y_true)
    best_acc = 0
    best_th = 0
    for i in range(len(y_score)):
        th = y_score[i]
        y_test = (y_score >= th)
        acc = np.mean((y_test == y_true).astype(int))
        if acc > best_acc:
            best_acc = acc
            best_th = th
    return best_acc, best_th

In [13]:
def compute_accuracy(feature_dict, pair_list, test_root):
    with open(pair_list, 'r') as f:
        pairs = f.readlines()

    similarities = []
    labels = []
    for pair in pairs:
        img1, img2, label = pair.split()
        img1 = osp.join(test_root, img1)
        img2 = osp.join(test_root, img2)
        feature1 = feature_dict[img1].cpu().numpy()
        feature2 = feature_dict[img2].cpu().numpy()
        label = int(label)

        similarity = cosin_metric(feature1, feature2)
        similarities.append(similarity)
        labels.append(label)

    accuracy, threshold = threshold_search(similarities, labels)
    return accuracy, threshold

In [16]:
#for i in range(30):
test_model = "checkpoints/myFMN_"+str(19)+"_sample25%.pth" #conf.test_model
model = myFaceMobileNet(conf.embedding_size)
model = nn.DataParallel(model)
model.load_state_dict(torch.load(test_model, map_location=conf.device))
model.eval()

images = unique_image(conf.test_list)
images = [osp.join(conf.test_root, img) for img in images]
groups = group_image(images, conf.test_batch_size)

feature_dict = dict()
for group in groups:
    d = featurize(group, conf.test_transform, model, conf.device,low_light = False)
    feature_dict.update(d) 
accuracy, threshold = compute_accuracy(feature_dict, conf.test_list, conf.test_root) 

print(
    f"Test Model: {test_model}\n"
    f"Accuracy: {accuracy:.3f}\n"
    f"Threshold: {threshold:.3f}\n"
)

Test Model: checkpoints/myFMN_19_sample25%.pth
Accuracy: 0.558
Threshold: 0.772



# Get #params and FLOPs

In [15]:
with torch.cuda.device(0):
    #net = model
    model = net
    macs, params = get_model_complexity_info(net, (1,128,128), as_strings=False,
                                           print_per_layer_stat=False, verbose=True)
    print('{:<30}  {:<8}'.format('Computational complexity: ', macs*2))
    print('{:<30}  {:<8}'.format('Number of parameters: ', params))

Computational complexity:       557627392.0
Number of parameters:           1916096 


In [2]:
from model import myFaceMobileNet

In [3]:
model = myFaceMobileNet(conf.embedding_size)
model = nn.DataParallel(model)
#model.load_state_dict(torch.load(test_model, map_location=conf.device))
#model.eval()
