# Train

In [1]:
import os
import os.path as osp

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

from model import FaceMobileNet
from model.metric import ArcFace, CosFace
from model.loss import FocalLoss
from dataset import load_data
from config import Config as conf

import numpy as np
from PIL import Image

In [2]:
from dataset import load_data

In [3]:
load_data(conf)

NameError: name 'np' is not defined

In [2]:
os.chdir("D:\Course\Graduate1\Data_Science\hw\hw5")

In [3]:
# Data Setup
dataloader, class_num = load_data(conf, training=True)
embedding_size = conf.embedding_size
device = conf.device

In [4]:
# Network Setup
net = FaceMobileNet(embedding_size).to(device)

if conf.metric == 'arcface':
    metric = ArcFace(embedding_size, class_num).to(device)
else:
    metric = CosFace(embedding_size, class_num).to(device)

net = nn.DataParallel(net)
metric = nn.DataParallel(metric)

In [5]:
# Training Setup
if conf.loss == 'focal_loss':
    criterion = FocalLoss(gamma=2)
else:
    criterion = nn.CrossEntropyLoss()

if conf.optimizer == 'sgd':
    optimizer = optim.SGD([{'params': net.parameters()}, {'params': metric.parameters()}], 
                            lr=conf.lr, weight_decay=conf.weight_decay)
else:
    optimizer = optim.Adam([{'params': net.parameters()}, {'params': metric.parameters()}],
                            lr=conf.lr, weight_decay=conf.weight_decay)

scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=conf.lr_step, gamma=0.1)

In [6]:
# Checkpoints Setup
os.makedirs(conf.checkpoints, exist_ok=True)

In [12]:
# Start training
net.train()

for e in range(conf.epoch):
    for data, labels in tqdm(dataloader, desc=f"Epoch {e}/{conf.epoch}",
                             ascii=True, total=len(dataloader)):
        data = data.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        embeddings = net(data)
        thetas = metric(embeddings, labels)
        loss = criterion(thetas, labels)
        loss.backward()
        optimizer.step()

    print(f"Epoch {e}/{conf.epoch}, Loss: {loss}")

    backbone_path = osp.join(conf.checkpoints, f"{e}.pth")
    torch.save(net.state_dict(), backbone_path)
    scheduler.step()

Epoch 0/30: 100%|##########| 7119/7119 [14:31<00:00,  8.17it/s]


Epoch 0/30, Loss: 14.636209487915039


Epoch 1/30: 100%|##########| 7119/7119 [14:30<00:00,  8.18it/s]


Epoch 1/30, Loss: 13.404420852661133


Epoch 2/30: 100%|##########| 7119/7119 [14:30<00:00,  8.18it/s]


Epoch 2/30, Loss: 12.529593467712402


Epoch 3/30: 100%|##########| 7119/7119 [14:30<00:00,  8.18it/s]


Epoch 3/30, Loss: 11.526164054870605


Epoch 4/30: 100%|##########| 7119/7119 [14:30<00:00,  8.18it/s]


Epoch 4/30, Loss: 11.17046070098877


Epoch 5/30: 100%|##########| 7119/7119 [14:29<00:00,  8.18it/s]


Epoch 5/30, Loss: 9.362608909606934


Epoch 6/30: 100%|##########| 7119/7119 [14:29<00:00,  8.19it/s]


Epoch 6/30, Loss: 10.645035743713379


Epoch 7/30: 100%|##########| 7119/7119 [14:28<00:00,  8.20it/s]


Epoch 7/30, Loss: 10.336000442504883


Epoch 8/30: 100%|##########| 7119/7119 [14:30<00:00,  8.18it/s]


Epoch 8/30, Loss: 9.285191535949707


Epoch 9/30: 100%|##########| 7119/7119 [14:28<00:00,  8.19it/s]


Epoch 9/30, Loss: 9.937674522399902


Epoch 10/30: 100%|##########| 7119/7119 [14:28<00:00,  8.20it/s]


Epoch 10/30, Loss: 7.436448574066162


Epoch 11/30: 100%|##########| 7119/7119 [14:28<00:00,  8.20it/s]


Epoch 11/30, Loss: 7.5819902420043945


Epoch 12/30: 100%|##########| 7119/7119 [14:28<00:00,  8.20it/s]


Epoch 12/30, Loss: 6.005134105682373


Epoch 13/30: 100%|##########| 7119/7119 [14:27<00:00,  8.20it/s]


Epoch 13/30, Loss: 6.872140407562256


Epoch 14/30: 100%|##########| 7119/7119 [14:28<00:00,  8.20it/s]


Epoch 14/30, Loss: 5.648280143737793


Epoch 15/30: 100%|##########| 7119/7119 [14:28<00:00,  8.20it/s]


Epoch 15/30, Loss: 7.23828649520874


Epoch 16/30: 100%|##########| 7119/7119 [14:28<00:00,  8.20it/s]


Epoch 16/30, Loss: 7.725142002105713


Epoch 17/30: 100%|##########| 7119/7119 [14:28<00:00,  8.20it/s]


Epoch 17/30, Loss: 7.167911052703857


Epoch 18/30: 100%|##########| 7119/7119 [14:28<00:00,  8.20it/s]


Epoch 18/30, Loss: 6.90154504776001


Epoch 19/30: 100%|##########| 7119/7119 [14:28<00:00,  8.20it/s]


Epoch 19/30, Loss: 8.444781303405762


Epoch 20/30: 100%|##########| 7119/7119 [14:28<00:00,  8.20it/s]


Epoch 20/30, Loss: 6.428404808044434


Epoch 21/30: 100%|##########| 7119/7119 [14:48<00:00,  8.01it/s]


Epoch 21/30, Loss: 6.322268962860107


Epoch 22/30: 100%|##########| 7119/7119 [14:29<00:00,  8.19it/s]


Epoch 22/30, Loss: 7.900283336639404


Epoch 23/30: 100%|##########| 7119/7119 [14:21<00:00,  8.26it/s]


Epoch 23/30, Loss: 6.582077980041504


Epoch 24/30: 100%|##########| 7119/7119 [14:50<00:00,  7.99it/s]


Epoch 24/30, Loss: 6.9685797691345215


Epoch 25/30: 100%|##########| 7119/7119 [15:11<00:00,  7.81it/s]


Epoch 25/30, Loss: 7.759799480438232


Epoch 26/30: 100%|##########| 7119/7119 [15:19<00:00,  7.74it/s]


Epoch 26/30, Loss: 4.772459983825684


Epoch 27/30: 100%|##########| 7119/7119 [15:18<00:00,  7.75it/s]


Epoch 27/30, Loss: 5.544327735900879


Epoch 28/30: 100%|##########| 7119/7119 [15:04<00:00,  7.87it/s]


Epoch 28/30, Loss: 5.775236129760742


Epoch 29/30: 100%|##########| 7119/7119 [14:53<00:00,  7.96it/s]

Epoch 29/30, Loss: 6.29815149307251





# Test

In [3]:
def unique_image(pair_list) -> set:
    """Return unique image path in pair_list.txt"""
    with open(pair_list, 'r') as fd:
        pairs = fd.readlines()
    unique = set()
    for pair in pairs:
        id1, id2, _ = pair.split()
        unique.add(id1)
        unique.add(id2)
    return unique

In [4]:
def group_image(images: set, batch) -> list:
    """Group image paths by batch size"""
    images = list(images)
    size = len(images)
    res = []
    for i in range(0, size, batch):
        end = min(batch + i, size)
        res.append(images[i : end])
    return res

In [5]:
def _preprocess(images: list, transform, low_light = False) -> torch.Tensor:
    res = []
    for img in images:
        im = Image.open(img)
        if low_light:
            im = (np.sqrt (im)*2).astype(np.uint8)
            im = Image.fromarray(im)
        im = transform(im)
        res.append(im)
    data = torch.cat(res, dim=0)  # shape: (batch, 128, 128)
    data = data[:, None, :, :]    # shape: (batch, 1, 128, 128)
    return data

In [6]:
def featurize(images: list, transform, net, device,low_light = False) -> dict:
    """featurize each image and save into a dictionary
    Args:
        images: image paths
        transform: test transform
        net: pretrained model
        device: cpu or cuda
    Returns:
        Dict (key: imagePath, value: feature)
    """
    data = _preprocess(images, transform,low_light = low_light)
    data = data.to(device)
    net = net.to(device)
    with torch.no_grad():
        features = net(data) 
    res = {img: feature for (img, feature) in zip(images, features)}
    return res

In [7]:
def cosin_metric(x1, x2):
    return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))

In [8]:
def threshold_search(y_score, y_true):
    y_score = np.asarray(y_score)
    y_true = np.asarray(y_true)
    best_acc = 0
    best_th = 0
    for i in range(len(y_score)):
        th = y_score[i]
        y_test = (y_score >= th)
        acc = np.mean((y_test == y_true).astype(int))
        if acc > best_acc:
            best_acc = acc
            best_th = th
    return best_acc, best_th

In [9]:
def compute_accuracy(feature_dict, pair_list, test_root):
    with open(pair_list, 'r') as f:
        pairs = f.readlines()

    similarities = []
    labels = []
    for pair in pairs:
        img1, img2, label = pair.split()
        img1 = osp.join(test_root, img1)
        img2 = osp.join(test_root, img2)
        feature1 = feature_dict[img1].cpu().numpy()
        feature2 = feature_dict[img2].cpu().numpy()
        label = int(label)

        similarity = cosin_metric(feature1, feature2)
        similarities.append(similarity)
        labels.append(label)

    accuracy, threshold = threshold_search(similarities, labels)
    return accuracy, threshold

In [11]:
model = FaceMobileNet(conf.embedding_size)
model = nn.DataParallel(model)
model.load_state_dict(torch.load(conf.test_model, map_location=conf.device))
model.eval()

images = unique_image(conf.test_list)
images = [osp.join(conf.test_root, img) for img in images]
groups = group_image(images, conf.test_batch_size)

feature_dict = dict()
for group in groups:
    d = featurize(group, conf.test_transform, model, conf.device,low_light = False)
    feature_dict.update(d) 
accuracy, threshold = compute_accuracy(feature_dict, conf.test_list, conf.test_root) 

print(
    f"Test Model: {conf.test_model}\n"
    f"Accuracy: {accuracy:.3f}\n"
    f"Threshold: {threshold:.3f}\n"
)

Test Model: checkpoints/26.pth
Accuracy: 0.945
Threshold: 0.308

