In [3]:
from torch import nn
import torch.utils.data as data
import torch
import math
import numpy as np
import os


class FaceFeatureExtractor(nn.Module):
    """Embedding head that turns a feature map into a flat embedding vector.

    Pipeline:
      1. 1x1 conv ``in_c -> latent_c`` + BatchNorm + LeakyReLU,
      2. 7x7 depthwise conv (``groups == latent_c``, no padding) + BatchNorm,
         with no activation afterwards,
      3. 1x1 conv ``latent_c -> embedding_size``,
      4. flatten.

    For a 7x7 spatial input the depthwise conv reduces the map to 1x1, so the
    output has shape ``(batch, embedding_size)``. Any other spatial size
    ``(H, W)`` yields ``embedding_size * (H - 6) * (W - 6)`` features after
    the flatten.

    Args:
        latent_c: channel count of the intermediate (expanded) representation.
        in_c: channel count of the incoming feature map (default 128).
        embedding_size: channel count of the output embedding (default 128).
        leaky_alpha: negative slope of the LeakyReLU (default 0.1).
    """

    def __init__(self, latent_c=256, in_c=128, embedding_size=128, leaky_alpha=0.1):
        super().__init__()

        # define layers (attribute names are part of the state_dict keys; keep them stable)
        self.leaky_alpha = leaky_alpha
        self.conv_sep = nn.Conv2d(in_channels=in_c, out_channels=latent_c, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), groups=1, bias=False)
        self.bn_sep = nn.BatchNorm2d(latent_c)
        self.lrelu = nn.LeakyReLU(negative_slope=self.leaky_alpha)
        # Depthwise 7x7 with no padding: collapses a 7x7 map to 1x1 per channel.
        self.conv_dw = nn.Conv2d(latent_c, out_channels=latent_c, kernel_size=(7, 7), stride=(1, 1), padding=(0, 0), groups=latent_c, bias=False)
        self.bn_dw = nn.BatchNorm2d(latent_c)
        # Final 1x1 projection to the embedding dimension (no norm/activation after it).
        self.conv_fin = nn.Conv2d(latent_c, out_channels=embedding_size, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), groups=1, bias=False)
        self.flatten = nn.Flatten()

        self._init_weights()

    def _init_weights(self):
        """Normal init for convs with std sqrt(2 / (kh * kw * out_channels)); BatchNorm set to identity."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, 0.0, math.sqrt(2.0 / n))
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Map ``x`` of shape ``(batch, in_c, H, W)`` to flattened embeddings."""
        x = self.conv_sep(x)
        x = self.bn_sep(x)
        x = self.lrelu(x)
        x = self.conv_dw(x)
        x = self.bn_dw(x)
        x = self.conv_fin(x)
        return self.flatten(x)

class LatentData(data.Dataset):
    """Map-style dataset over precomputed latent feature maps stored as ``.npy`` files.

    The directory ``path`` must contain two index files:
      * ``paths.npy``  - the file path of each sample,
      * ``labels.npy`` - the class label of each sample, aligned with ``paths``.
    Sample files are loaded lazily, one per ``__getitem__`` call.
    """

    def __init__(self, path):
        self.paths, self.labels = (
            np.load(os.path.join(path, index_file))
            for index_file in ("paths.npy", "labels.npy")
        )

    def __getitem__(self, index):
        """Return ``(features, label)``: the sample with singleton dims squeezed out, and its label as an int."""
        features = np.squeeze(np.load(self.paths[index]))
        return features, int(self.labels[index])

    def __len__(self):
        # The number of samples is defined by the label index file.
        return len(self.labels)

def l2_norm(input, axis=1, eps=1e-12):
    """Divide ``input`` by its L2 norm along ``axis``.

    Args:
        input: tensor to normalize.
        axis: dimension along which the norm is taken; the norm keeps that
            dimension (size 1) so it broadcasts back over ``input``.
        eps: lower bound on the norm. Without it an all-zero slice gives
            0/0 = NaN; with it such slices come back as zeros. Slices whose
            norm exceeds ``eps`` are unaffected.

    Returns:
        Tensor of the same shape as ``input``.
    """
    norm = torch.norm(input, 2, axis, True).clamp_min(eps)
    return torch.div(input, norm)

class Arcface(nn.Module):
    """ArcFace (additive angular margin) classification head.

    Implements the margin of "ArcFace: Additive Angular Margin Loss for Deep
    Face Recognition" (https://arxiv.org/abs/1801.07698). For each sample,
    the target-class logit cos(theta) is replaced by cos(theta + m). All
    logits are then multiplied by ``s`` (scaling introduced in NormFace,
    https://arxiv.org/abs/1704.06369). The result is meant to be fed to a
    softmax cross-entropy loss.

    ``embbedings`` passed to ``forward`` are assumed to be L2-normalized rows,
    so that ``embbedings @ normalized_kernel`` are cosines. The code does not
    normalize them itself.

    Args:
        embedding_size: dimensionality of the input embeddings.
        classnum: number of classes (columns of ``kernel``).
        s: logit scale.
        m: additive angular margin, in radians.
        eps: lower bound applied to sin^2(theta) before the square root, so
            the gradient stays finite when cos(theta) == +/-1.
    """

    def __init__(self, embedding_size=128, classnum=51332, s=64., m=0.5, eps=1e-7):
        super(Arcface, self).__init__()
        self.classnum = classnum
        # Class-center matrix: one column per class.
        self.kernel = nn.Parameter(torch.Tensor(embedding_size, classnum))
        # NOTE: this xavier init is overwritten by the next line. It is kept so the
        # RNG stream consumed here (and thus later random inits) is unchanged.
        nn.init.xavier_uniform_(self.kernel)
        # Uniform init, then renorm each column (dim=1 sub-tensors) to norm <= 1e-5
        # and scale by 1e5, i.e. columns end up with L2 norm <= 1.
        self.kernel.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        self.m = m  # the margin value, default is 0.5
        self.s = s  # scalar value, default is 64
        self.eps = eps
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # CosFace-style penalty used when theta + m would leave [0, pi].
        # sin(m) * m == sin(pi - m) * m, the form used by the reference implementation.
        self.mm = self.sin_m * m
        # theta + m <= pi  <=>  cos(theta) >= cos(pi - m)
        self.threshold = math.cos(math.pi - m)

    def forward(self, embbedings, label):
        """Return scaled logits of shape ``(batch, classnum)`` with the margin applied at ``label``.

        Args:
            embbedings: ``(batch, embedding_size)`` tensor (assumed L2-normalized rows).
            label: ``(batch,)`` integer tensor of target class indices.
        """
        nB = embbedings.size(0)
        # Normalize each class center (column). F.normalize clamps tiny norms,
        # so an all-zero column cannot produce NaN.
        kernel_norm = nn.functional.normalize(self.kernel, p=2, dim=0)
        cos_theta = torch.mm(embbedings, kernel_norm).clamp(-1, 1)  # clamp for numerical stability

        # Only the target logits receive the margin, so compute cos(theta + m) for them alone.
        idx_ = torch.arange(nB, dtype=torch.long, device=cos_theta.device)
        target_cos = cos_theta[idx_, label]
        # Clamping sin^2 away from 0 keeps d sqrt / dx finite. Below eps the clamp
        # passes zero gradient instead of inf/NaN.
        sin_theta = torch.sqrt((1.0 - target_cos * target_cos).clamp(min=self.eps))
        target_cos_m = target_cos * self.cos_m - sin_theta * self.sin_m
        # When theta + m would exceed pi, fall back to the CosFace-style penalty.
        target_logit = torch.where(target_cos > self.threshold, target_cos_m, target_cos - self.mm)

        output = cos_theta.clone()  # keep cos_theta intact for autograd
        output[idx_, label] = target_logit
        return output * self.s  # scale up so softmax can separate classes

In [4]:
import time
from torch.nn.functional import normalize

CASIA_WEBFACE_PATH = "C:\\workspace\\facenet\\CASIA_webface\\"
LATENT_DATA_PATH = "C:\\workspace\\facenet\\latent_train_data_npy\\"
BATCH_SIZE = 64
EPOCHS = 12
LOG_EVERY = 100    # iterations between progress reports
CKPT_EVERY = 3000  # iterations between checkpoints

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

dataset_train = LatentData(LATENT_DATA_PATH)
net = FaceFeatureExtractor(512).to(device)
margin = Arcface(s=32., m=0.5).to(device)

# `data` is the torch.utils.data module alias. The loop below must not rebind it,
# otherwise re-running this cell fails here.
dataloader_train = data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(
    [{'params': net.parameters(), 'weight_decay': 5e-4},
     {'params': margin.parameters(), 'weight_decay': 5e-4}],
    lr=0.01, momentum=0.9, nesterov=True)

scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[6, 8, 10], gamma=0.3)
start = time.time()

train_logging = 'train_logging.txt'

best_acc = 0.
best_iters = 0
total_iters = 0

# Timer for the per-iteration report. It is reset only when a report is emitted,
# so every interval spans exactly LOG_EVERY iterations, even across epoch boundaries.
since = time.time()
for epoch in range(EPOCHS):
    net.train()
    for inputs, label in dataloader_train:
        inputs, label = inputs.to(device), label.to(device)
        optimizer.zero_grad()

        with torch.set_grad_enabled(True):
            raw_out = net(inputs)
            embeddings = torch.nn.functional.normalize(raw_out)  # unit-norm rows for ArcFace
            margin_out = margin(embeddings, label)
            loss = criterion(margin_out, label)
            loss.backward()
        optimizer.step()
        total_iters += 1

        if total_iters % LOG_EVERY == 0:
            # NOTE: accuracy is measured on margin-penalized logits, so it is a
            # pessimistic estimate of plain cosine-classifier accuracy.
            preds = margin_out.detach().argmax(dim=1)
            correct = (preds == label).sum().item()
            total = label.size(0)
            time_cur = (time.time() - since) / LOG_EVERY
            since = time.time()
            lr = optimizer.param_groups[-1]['lr']
            msg = ("Epoch {}/{}, Iters: {:0>6d}, loss: {:.4f}, train_accuracy: {:.4f}, time: {:.2f} s/iter, learning rate: {}"
                   .format(epoch, EPOCHS-1, total_iters, loss.item(), correct/total, time_cur, lr))
            print(msg)
            with open(train_logging, 'a') as f:
                f.write(msg + '\n')

        if total_iters % CKPT_EVERY == 0:
            # Both checkpoints store their weights under 'net_state_dict' (kept for loader compatibility).
            for module, ckpt_name in ((net, 'Iter_%06d_model.ckpt'), (margin, 'Iter_%06d_margin.ckpt')):
                torch.save({'iters': total_iters, 'net_state_dict': module.state_dict()},
                           ckpt_name % total_iters)
    scheduler.step()

Epoch 0/11, Iters: 000100, loss: 28.1051, train_accuracy: 0.0000, time: 1.09 s/iter, learning rate: 0.01
Epoch 0/11, Iters: 000200, loss: 26.7360, train_accuracy: 0.0000, time: 1.01 s/iter, learning rate: 0.01
Epoch 0/11, Iters: 000300, loss: 25.2333, train_accuracy: 0.0000, time: 1.05 s/iter, learning rate: 0.01
Epoch 0/11, Iters: 000400, loss: 25.5040, train_accuracy: 0.0000, time: 1.17 s/iter, learning rate: 0.01
Epoch 0/11, Iters: 000500, loss: 25.4533, train_accuracy: 0.0000, time: 1.26 s/iter, learning rate: 0.01
Epoch 0/11, Iters: 000600, loss: 23.2878, train_accuracy: 0.0000, time: 1.27 s/iter, learning rate: 0.01
Epoch 0/11, Iters: 000700, loss: 23.8006, train_accuracy: 0.0000, time: 1.35 s/iter, learning rate: 0.01
Epoch 0/11, Iters: 000800, loss: 23.4216, train_accuracy: 0.0000, time: 1.41 s/iter, learning rate: 0.01
Epoch 0/11, Iters: 000900, loss: 23.1530, train_accuracy: 0.0000, time: 1.17 s/iter, learning rate: 0.01
Epoch 0/11, Iters: 001000, loss: 22.8171, train_accurac

In [5]:
# Snapshot the embedding network and the ArcFace head at the current iteration.
# Note: the margin checkpoint also stores its weights under the 'net_state_dict' key.
for module, ckpt_pattern in ((net, 'Iter_%06d_model.ckpt'), (margin, 'Iter_%06d_margin.ckpt')):
    checkpoint = {'iters': total_iters, 'net_state_dict': module.state_dict()}
    torch.save(checkpoint, ckpt_pattern % total_iters)

In [20]:
# Sanity-check the tensor shapes the network consumes: one raw sample, then one batch.
# Uses fresh names so the global `data` (torch.utils.data alias) is not rebound.
sample_features, sample_label = dataset_train[0]
print(sample_features.shape)

sample_inputs, sample_labels = next(iter(dataloader_train))
print(sample_inputs.shape)
# FaceFeatureExtractor and the ONNX dummy input expect (128, 7, 7) per sample.
assert tuple(sample_inputs.shape[1:]) == (128, 7, 7), sample_inputs.shape

torch.Size([64, 128, 7, 7])


In [21]:
import torch.onnx

# Export only the embedding network (not the ArcFace head) to ONNX.
# The dummy input must be on the same device as `net`. After training on CUDA,
# a CPU tensor would make tracing fail with a device-mismatch error.
net.eval()
dummy_input = torch.randn(1, 128, 7, 7, device=next(net.parameters()).device)

torch.onnx.export(net, dummy_input, "facenet.onnx", export_params=True, opset_version=14, do_constant_folding=True,
    input_names=["input"], output_names=["output"])