# Train

In [4]:
import os
import os.path as osp

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

from model import FaceMobileNet
from model.metric import ArcFace, CosFace
from model.loss import FocalLoss
from dataset import load_data
from config import Config as conf

import numpy as np
from PIL import Image
from ptflops import get_model_complexity_info

In [2]:
os.chdir("D:\Course\Graduate1\Data_Science\hw\hw5")

In [3]:
# Data Setup
dataloader, class_num = load_data(conf, training=True)
embedding_size = conf.embedding_size
device = conf.device

10574Total train images: 18325


In [4]:
# Network Setup
net = FaceMobileNet(embedding_size).to(device)

if conf.metric == 'arcface':
    metric = ArcFace(embedding_size, class_num).to(device)
else:
    metric = CosFace(embedding_size, class_num).to(device)

net = nn.DataParallel(net)
metric = nn.DataParallel(metric)

In [5]:
# Training Setup
if conf.loss == 'focal_loss':
    criterion = FocalLoss(gamma=2)
else:
    criterion = nn.CrossEntropyLoss()

if conf.optimizer == 'sgd':
    optimizer = optim.SGD([{'params': net.parameters()}, {'params': metric.parameters()}], 
                            lr=conf.lr, weight_decay=conf.weight_decay)
else:
    optimizer = optim.Adam([{'params': net.parameters()}, {'params': metric.parameters()}],
                            lr=conf.lr, weight_decay=conf.weight_decay)

scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=conf.lr_step, gamma=0.1)

In [6]:
# Checkpoints Setup
os.makedirs(conf.checkpoints, exist_ok=True)

In [7]:
# Start training
net.train()

for e in range(conf.epoch):
    for data, labels in tqdm(dataloader, desc=f"Epoch {e}/{conf.epoch}",
                             ascii=True, total=len(dataloader)):
        data = data.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        embeddings = net(data)
        thetas = metric(embeddings, labels)
        loss = criterion(thetas, labels)
        loss.backward()
        optimizer.step()

    print(f"Epoch {e}/{conf.epoch}, Loss: {loss}")

    backbone_path = osp.join(conf.checkpoints, f"{e}_sample{int(conf.train_sample_rate*100)}%.pth")
    torch.save(net.state_dict(), backbone_path)
    scheduler.step()

Epoch 0/30: 100%|##########| 4/4 [00:12<00:00,  3.04s/it]


Epoch 0/30, Loss: 20.30253791809082


Epoch 1/30: 100%|##########| 4/4 [00:06<00:00,  1.56s/it]


Epoch 1/30, Loss: 8.592578887939453


Epoch 2/30: 100%|##########| 4/4 [00:06<00:00,  1.55s/it]


Epoch 2/30, Loss: 3.249884605407715


Epoch 3/30: 100%|##########| 4/4 [00:06<00:00,  1.56s/it]


Epoch 3/30, Loss: 0.34525182843208313


Epoch 4/30: 100%|##########| 4/4 [00:06<00:00,  1.54s/it]


Epoch 4/30, Loss: 5.512691497802734


Epoch 5/30: 100%|##########| 4/4 [00:06<00:00,  1.55s/it]


Epoch 5/30, Loss: 5.4474382400512695


Epoch 6/30: 100%|##########| 4/4 [00:06<00:00,  1.55s/it]


Epoch 6/30, Loss: 2.138883590698242


Epoch 7/30: 100%|##########| 4/4 [00:06<00:00,  1.55s/it]


Epoch 7/30, Loss: 1.4610414505004883


Epoch 8/30: 100%|##########| 4/4 [00:06<00:00,  1.59s/it]


Epoch 8/30, Loss: 2.77401065826416


Epoch 9/30: 100%|##########| 4/4 [00:06<00:00,  1.57s/it]


Epoch 9/30, Loss: 5.009413719177246


Epoch 10/30: 100%|##########| 4/4 [00:06<00:00,  1.53s/it]


Epoch 10/30, Loss: 2.253570079803467


Epoch 11/30: 100%|##########| 4/4 [00:06<00:00,  1.55s/it]


Epoch 11/30, Loss: 3.1799182891845703


Epoch 12/30: 100%|##########| 4/4 [00:06<00:00,  1.55s/it]


Epoch 12/30, Loss: 3.4078612327575684


Epoch 13/30: 100%|##########| 4/4 [00:06<00:00,  1.55s/it]


Epoch 13/30, Loss: 4.071030616760254


Epoch 14/30: 100%|##########| 4/4 [00:06<00:00,  1.55s/it]


Epoch 14/30, Loss: 2.1476693153381348


Epoch 15/30: 100%|##########| 4/4 [00:06<00:00,  1.55s/it]


Epoch 15/30, Loss: 4.304543495178223


Epoch 16/30: 100%|##########| 4/4 [00:06<00:00,  1.54s/it]


Epoch 16/30, Loss: 1.6566944122314453


Epoch 17/30: 100%|##########| 4/4 [00:06<00:00,  1.59s/it]


Epoch 17/30, Loss: 2.985196590423584


Epoch 18/30: 100%|##########| 4/4 [00:06<00:00,  1.50s/it]


Epoch 18/30, Loss: 4.5303730964660645


Epoch 19/30: 100%|##########| 4/4 [00:05<00:00,  1.50s/it]


Epoch 19/30, Loss: 1.1936410665512085


Epoch 20/30: 100%|##########| 4/4 [00:06<00:00,  1.54s/it]


Epoch 20/30, Loss: 3.9509010314941406


Epoch 21/30: 100%|##########| 4/4 [00:06<00:00,  1.54s/it]


Epoch 21/30, Loss: 1.7074648141860962


Epoch 22/30: 100%|##########| 4/4 [00:06<00:00,  1.55s/it]


Epoch 22/30, Loss: 0.9064067006111145


Epoch 23/30: 100%|##########| 4/4 [00:06<00:00,  1.55s/it]


Epoch 23/30, Loss: 0.9099773168563843


Epoch 24/30: 100%|##########| 4/4 [00:06<00:00,  1.59s/it]


Epoch 24/30, Loss: 1.4772956371307373


Epoch 25/30: 100%|##########| 4/4 [00:06<00:00,  1.55s/it]


Epoch 25/30, Loss: 1.4431569576263428


Epoch 26/30: 100%|##########| 4/4 [00:05<00:00,  1.50s/it]


Epoch 26/30, Loss: 1.6583263874053955


Epoch 27/30: 100%|##########| 4/4 [00:06<00:00,  1.55s/it]


Epoch 27/30, Loss: 1.4723771810531616


Epoch 28/30: 100%|##########| 4/4 [00:06<00:00,  1.55s/it]


Epoch 28/30, Loss: 1.989856481552124


Epoch 29/30: 100%|##########| 4/4 [00:06<00:00,  1.55s/it]

Epoch 29/30, Loss: 0.6911956071853638





# Test

In [8]:
def unique_image(pair_list) -> set:
    """Return unique image path in pair_list.txt"""
    with open(pair_list, 'r') as fd:
        pairs = fd.readlines()
    unique = set()
    for pair in pairs:
        id1, id2, _ = pair.split()
        unique.add(id1)
        unique.add(id2)
    return unique

In [9]:
def group_image(images: set, batch) -> list:
    """Group image paths by batch size"""
    images = list(images)
    size = len(images)
    res = []
    for i in range(0, size, batch):
        end = min(batch + i, size)
        res.append(images[i : end])
    return res

In [10]:
def _preprocess(images: list, transform, low_light = False) -> torch.Tensor:
    res = []
    for img in images:
        im = Image.open(img)
        if low_light:
            im = (np.sqrt (im)*2).astype(np.uint8)
            im = Image.fromarray(im)
        im = transform(im)
        res.append(im)
    data = torch.cat(res, dim=0)  # shape: (batch, 128, 128)
    data = data[:, None, :, :]    # shape: (batch, 1, 128, 128)
    return data

In [11]:
def featurize(images: list, transform, net, device,low_light = False) -> dict:
    """featurize each image and save into a dictionary
    Args:
        images: image paths
        transform: test transform
        net: pretrained model
        device: cpu or cuda
    Returns:
        Dict (key: imagePath, value: feature)
    """
    data = _preprocess(images, transform,low_light = low_light)
    data = data.to(device)
    net = net.to(device)
    with torch.no_grad():
        features = net(data) 
    res = {img: feature for (img, feature) in zip(images, features)}
    return res

In [12]:
def cosin_metric(x1, x2):
    return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))

In [13]:
def threshold_search(y_score, y_true):
    y_score = np.asarray(y_score)
    y_true = np.asarray(y_true)
    best_acc = 0
    best_th = 0
    for i in range(len(y_score)):
        th = y_score[i]
        y_test = (y_score >= th)
        acc = np.mean((y_test == y_true).astype(int))
        if acc > best_acc:
            best_acc = acc
            best_th = th
    return best_acc, best_th

In [14]:
def compute_accuracy(feature_dict, pair_list, test_root):
    with open(pair_list, 'r') as f:
        pairs = f.readlines()

    similarities = []
    labels = []
    for pair in pairs:
        img1, img2, label = pair.split()
        img1 = osp.join(test_root, img1)
        img2 = osp.join(test_root, img2)
        feature1 = feature_dict[img1].cpu().numpy()
        feature2 = feature_dict[img2].cpu().numpy()
        label = int(label)

        similarity = cosin_metric(feature1, feature2)
        similarities.append(similarity)
        labels.append(label)

    accuracy, threshold = threshold_search(similarities, labels)
    return accuracy, threshold

In [16]:
test_model = "checkpoints/28_sample10%.pth" #conf.test_model
model = FaceMobileNet(conf.embedding_size)
model = nn.DataParallel(model)
model.load_state_dict(torch.load(test_model, map_location=conf.device))
model.eval()

images = unique_image(conf.test_list)
images = [osp.join(conf.test_root, img) for img in images]
groups = group_image(images, conf.test_batch_size)

feature_dict = dict()
for group in groups:
    d = featurize(group, conf.test_transform, model, conf.device,low_light = True)
    feature_dict.update(d) 
accuracy, threshold = compute_accuracy(feature_dict, conf.test_list, conf.test_root) 

print(
    f"Test Model: {test_model}\n"
    f"Accuracy: {accuracy:.3f}\n"
    f"Threshold: {threshold:.3f}\n"
)

Test Model: checkpoints/28_sample10%.pth
Accuracy: 0.545
Threshold: 0.999



# Get #params and FLOPs

In [31]:
with torch.cuda.device(0):
    net = model
    macs, params = get_model_complexity_info(net, (1,128,128), as_strings=False,
                                           print_per_layer_stat=False, verbose=True)
    print('{:<30}  {:<8}'.format('Computational complexity: ', macs*2))
    print('{:<30}  {:<8}'.format('Number of parameters: ', params))

Computational complexity:       593278976.0
Number of parameters:           1985728 


In [1]:
from model import myFaceMobileNet

ImportError: cannot import name 'myFaceMobileNet' from 'model' (D:\Course\Graduate1\Data_Science\hw\hw5\model\__init__.py)

In [None]:
model = FaceMobileNet(conf.embedding_size)
model = nn.DataParallel(model)
model.load_state_dict(torch.load(test_model, map_location=conf.device))
model.eval()
