In [10]:
import os
from PIL import Image

from torch.utils.data import Dataset
from torchvision import transforms


class LFW4Training(Dataset):
    def __init__(self, train_file: str, img_folder: str):
        self.img_folder = img_folder

        names = os.listdir(img_folder)
        self.name2label = {name: idx for idx, name in enumerate(names)}
        self.n_label = len(self.name2label)

        with open(train_file) as f:
            train_meta_info = f.read().splitlines()

        self.train_list = []
        for line in train_meta_info:
            line = line.split("\t")
            if len(line) == 3:
                self.train_list.append(os.path.join(line[0], line[0] + "_" + str(line[1]).zfill(4) + ".jpg"))
                self.train_list.append(os.path.join(line[0], line[0] + "_" + str(line[2]).zfill(4) + ".jpg"))
            elif len(line) == 4:
                self.train_list.append(os.path.join(line[0], line[0] + "_" + str(line[1]).zfill(4) + ".jpg"))
                self.train_list.append(os.path.join(line[2], line[2] + "_" + str(line[3]).zfill(4) + ".jpg"))
            else:
                pass

        self.transform = transforms.Compose([
            transforms.Resize(96),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                 std=[0.5, 0.5, 0.5]),
        ])

    def __getitem__(self, index):
        img_path = self.train_list[index]

        img = Image.open(os.path.join(self.img_folder, img_path))
        img = self.transform(img)

        name = img_path.split("/")[0]
        label = self.name2label[name]

        return img, label

    def __len__(self):
        return len(self.train_list)


class LFW4Eval(Dataset):
    def __init__(self, eval_file: str, img_folder: str):
        self.img_folder = img_folder

        with open(eval_file) as f:
            eval_meta_info = f.read().splitlines()

        self.eval_list = []
        for line in eval_meta_info:
            line = line.split("\t")
            if len(line) == 3:
                eval_pair = (
                    os.path.join(line[0], line[0] + "_" + str(line[1]).zfill(4) + ".jpg"),
                    os.path.join(line[0], line[0] + "_" + str(line[2]).zfill(4) + ".jpg"),
                    1,
                )
                self.eval_list.append(eval_pair)
            elif len(line) == 4:
                eval_pair = (
                    os.path.join(line[0], line[0] + "_" + str(line[1]).zfill(4) + ".jpg"),
                    os.path.join(line[2], line[2] + "_" + str(line[3]).zfill(4) + ".jpg"),
                    0,
                )
                self.eval_list.append(eval_pair)
            else:
                pass

        self.transform = transforms.Compose([
            transforms.Resize(96),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                 std=[0.5, 0.5, 0.5]),
        ])

    def __getitem__(self, index):
        img_1_path, img_2_path, label = self.eval_list[index]

        img_1 = Image.open(os.path.join(self.img_folder, img_1_path))
        img_2 = Image.open(os.path.join(self.img_folder, img_2_path))
        img_1 = self.transform(img_1)
        img_2 = self.transform(img_2)

        return img_1, img_2, label

    def __len__(self):
        return len(self.eval_list)

In [11]:
import torch
from torch import nn
import torch.nn.functional as F


class AngularPenaltySMLoss(nn.Module):
    def __init__(self, in_features, out_features, eps=1e-7, m=None):
        super(AngularPenaltySMLoss, self).__init__()

        self.m = 4. if not m else m

        self.in_features = in_features
        self.out_features = out_features
        self.fc = nn.Linear(in_features, out_features, bias=False)
        self.eps = eps

    def forward(self, x, labels):

        '''
        input shape (N, in_features)
        '''
        assert len(x) == len(labels)
        assert torch.min(labels) >= 0
        assert torch.max(labels) < self.out_features

        for W in self.fc.parameters():
            W = F.normalize(W, p=2, dim=1)

        x = F.normalize(x, p=2, dim=1)

        wf = self.fc(x)

        numerator = torch.cos(self.m * torch.acos(
            torch.clamp(torch.diagonal(wf.transpose(0, 1)[labels]), -1. + self.eps, 1 - self.eps)))

        excl = torch.cat([torch.cat((wf[i, :y], wf[i, y + 1:])).unsqueeze(0) for i, y in enumerate(labels)], dim=0)
        denominator = torch.exp(numerator) + torch.sum(torch.exp(excl), dim=1)
        L = numerator - torch.log(denominator)

        return -torch.mean(L)


class SphereCNN(nn.Module):
    def __init__(self, class_num: int, feature=False):
        super(SphereCNN, self).__init__()
        self.class_num = class_num
        self.feature = feature

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=128, kernel_size=3, stride=2)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2)
        self.conv4 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2)

        self.fc5 = nn.Linear(512 * 5 * 5, 512)
        self.angular = AngularPenaltySMLoss(512, self.class_num)

    def forward(self, x, y=None):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))

        x = x.view(x.size(0), -1)
        x = self.fc5(x)

        if self.feature or y is None:
            return x
        else:
            x_angle = self.angular(x, y)
            return x, x_angle



if __name__ == "__main__":
    net = SphereCNN(50)
    input = torch.ones(64, 3, 96, 96)
    output = net(input, None)

In [12]:
import argparse


def parse_args():
    parser = argparse.ArgumentParser(description="SphereFace")

    parser.add_argument('--seed', type=int, default=2021)
    parser.add_argument('--device', type=str, default="cuda:0")

    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--eval_interval', type=int, default=20)

    parser.add_argument('--train_file', type=str, default='/content/pairsDevTrain.txt')
    parser.add_argument('--eval_file', type=str, default='/content/pairsDevTest.txt')
    parser.add_argument('--img_folder', type=str, default='/content/lfw')

    return parser.parse_args()

In [13]:
import sys

sys.argv = [
    'main.py',  # This is just a placeholder for the script name and isn't used in the argument parsing.
    '--seed', '2021',  # Seed for random number generators for reproducibility.
    '--device', 'cuda:0',  # Specify the device for computation (use 'cpu' if GPU is not available).
    '--batch_size', '128',  # Batch size for training and evaluation.
    '--epoch', '100',  # Number of epochs to train for.
    '--lr', '0.001',  # Learning rate for optimizer.
    '--eval_interval', '20',  # Interval (in epochs) at which to perform evaluation on the validation set.
    '--train_file', '/content/pairsDevTest.txt',  # Path to the training pairs file.
    '--eval_file', '/content/pairsDevTest.txt',  # Path to the evaluation/testing pairs file.
    '--img_folder', '/content/lfw'  # Path to the directory containing the LFW images.
]



In [14]:
import random
import numpy as np

import torch


def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    return


class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [15]:
import time

import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader


def eval(data_loader: DataLoader, model: SphereCNN, device: torch.device, threshold: float = 0.5):
    model.eval()
    model.feature = True
    sim_func = nn.CosineSimilarity()

    cnt = 0.
    total = 0.

    t1 = time.time()
    with torch.no_grad():
        for img_1, img_2, label in data_loader:
            img_1 = img_1.to(device)
            img_2 = img_2.to(device)
            label = label.to(device)

            feat_1 = model(img_1, None)
            feat_2 = model(img_2, None)
            sim = sim_func(feat_1, feat_2)

            sim[sim > threshold] = 1
            sim[sim <= threshold] = 0

            total += sim.size(0)
            for i in range(sim.size(0)):
                if sim[i] == label[i]:
                    cnt += 1

    print("Acc.: %.4f; Time: %.3f" % (cnt / total, time.time() - t1))
    return


def main():
    args = parse_args()

    set_seed(args.seed)
    device = torch.device(args.device)

    train_set = LFW4Training(args.train_file, args.img_folder)
    eval_set = LFW4Eval(args.eval_file, args.img_folder)
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
    eval_loader = DataLoader(eval_set, batch_size=args.batch_size)

    model = SphereCNN(class_num=train_set.n_label)
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    loss_record = AverageMeter()
    for epoch in range(args.epoch):
        t1 = time.time()
        model.train()
        model.feature = False
        loss_record.reset()

        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            _, loss = model(inputs, targets)
            loss.backward()
            optimizer.step()

            loss_record.update(loss)

        print("Epoch: %s; Loss: %.3f; Time: %.3f" % (str(epoch).zfill(2), loss_record.avg, time.time() - t1))

        if (epoch + 1) % args.eval_interval == 0:
            eval(eval_loader, model, device)

    return


if __name__ == "__main__":
    main()

Epoch: 00; Loss: 7.617; Time: 7.276
Epoch: 01; Loss: 7.378; Time: 4.206
Epoch: 02; Loss: 7.115; Time: 3.838
Epoch: 03; Loss: 6.884; Time: 4.180
Epoch: 04; Loss: 6.688; Time: 4.581
Epoch: 05; Loss: 6.528; Time: 3.859
Epoch: 06; Loss: 6.401; Time: 3.831
Epoch: 07; Loss: 6.302; Time: 5.875
Epoch: 08; Loss: 6.223; Time: 4.040
Epoch: 09; Loss: 6.162; Time: 4.445
Epoch: 10; Loss: 6.116; Time: 4.856
Epoch: 11; Loss: 6.076; Time: 3.898
Epoch: 12; Loss: 6.044; Time: 3.861
Epoch: 13; Loss: 6.018; Time: 4.916
Epoch: 14; Loss: 5.995; Time: 3.868
Epoch: 15; Loss: 5.975; Time: 3.848
Epoch: 16; Loss: 5.956; Time: 4.912
Epoch: 17; Loss: 5.938; Time: 3.839
Epoch: 18; Loss: 5.918; Time: 3.924
Epoch: 19; Loss: 5.899; Time: 4.805
Acc.: 0.5000; Time: 3.432
Epoch: 20; Loss: 5.879; Time: 3.834
Epoch: 21; Loss: 5.857; Time: 4.194
Epoch: 22; Loss: 5.833; Time: 5.079
Epoch: 23; Loss: 5.807; Time: 3.880
Epoch: 24; Loss: 5.782; Time: 4.167
Epoch: 25; Loss: 5.755; Time: 4.667
Epoch: 26; Loss: 5.727; Time: 3.889
Ep