In [1]:
%cd ../

/workspaces/mmsegmentation/face_recognition


In [2]:
from facenet_pytorch import MTCNN, InceptionResnetV1, fixed_image_standardization, training
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from torch.nn import DataParallel
import numpy as np
import os
import os.path as osp
from tqdm import tqdm
import shutil
import glob
import random
from icecream import ic
import math

  from .autonotebook import tqdm as notebook_tqdm


# Run parameters

In [3]:
data_dir = 'face_aic'
nc = len(os.listdir('/workspaces/mmsegmentation/custom_data/face_data_aic'))

batch_size = 4
epochs = 30
workers = 0 if os.name == 'nt' else 8

In [4]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Running on device: {}'.format(device))

Running on device: cuda:0


# Dataset

In [5]:
train_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
    transforms.RandomRotation(degrees=(30, 70)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5]
    )
])

# the validation transforms
val_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5]
    )
])

In [6]:
train_dataset = torchvision.datasets.ImageFolder(
    root=osp.join(data_dir, 'train'),
    transform=train_transform
    )
val_dataset = torchvision.datasets.ImageFolder(
    root=osp.join(data_dir, 'val'),
    transform=val_transform
    )
    

In [7]:
train_loader = DataLoader(
    train_dataset,
    num_workers=workers,
    batch_size=batch_size,
    drop_last=True
)
val_loader = DataLoader(
    val_dataset,
    num_workers=workers,
    batch_size=batch_size,
    drop_last=True
)

# Model

In [8]:
class FocalLoss(nn.Module):

    def __init__(self, gamma=0, eps=1e-7):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.eps = eps
        self.ce = torch.nn.CrossEntropyLoss()

    def forward(self, input, target):
        logp = self.ce(input, target)
        p = torch.exp(-logp)
        loss = (1 - p) ** self.gamma * logp
        return loss.mean()

In [9]:
class ArcMarginProduct(nn.Module):
    def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, train, label=False):
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        if train:
            one_hot = torch.zeros(cosine.size(), device='cuda')
            one_hot.scatter_(1, label.cuda().view(-1, 1).long(), 1)
            output = (one_hot * phi) + ((1.0 - one_hot) * cosine)  # you can use torch.where if your torch.__version__ is 0.4
        else:
            output = cosine
        output *= self.s

        return output

In [10]:
resnet = InceptionResnetV1(
    classify=False,
    pretrained='vggface2',
    num_classes=nc
)

In [11]:
resnet

InceptionResnetV1(
  (conv2d_1a): BasicConv2d(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (conv2d_2a): BasicConv2d(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (conv2d_2b): BasicConv2d(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (maxpool_3a): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2d_3b): BasicConv2d(
    (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (conv2d_4a): 

In [12]:
for param in resnet.parameters():
    param.requires_grad = False

for param in resnet.logits.parameters():
    param.requires_grad = True

In [13]:
class ArcNet(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.base = net
        self.arcface = ArcMarginProduct(512, nc, s=30, m=0.5, easy_margin=False)
    
    def forward(self, x, label):
        x = self.base(x)
        if self.training:
            x = self.arcface(x, self.training, label)
        else:
            x = self.arcface(x, self.training)
        return x

In [14]:
model = ArcNet(resnet).to(device)
# model = DataParallel(model)

# Training

In [15]:
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = MultiStepLR(optimizer, [5, 10])

In [16]:
loss_fn = torch.nn.CrossEntropyLoss()
metrics = {
    'fps': training.BatchTimer(),
    'acc': training.accuracy
}

In [17]:
writer = SummaryWriter()
writer.iteration, writer.interval = 0, 10

print('\n\nInitial')
print('-' * 10)
model.eval()
training.pass_epoch(
    model, loss_fn, val_loader,
    batch_metrics=metrics, show_running=True, device=device,
    writer=writer
)

for epoch in range(epochs):
    print('\nEpoch {}/{}'.format(epoch + 1, epochs))
    print('-' * 10)

    model.train()
    training.pass_epoch(
        model, loss_fn, train_loader, optimizer, scheduler,
        batch_metrics=metrics, show_running=True, device=device,
        writer=writer
    )

    model.eval()
    training.pass_epoch(
        model, loss_fn, val_loader,
        batch_metrics=metrics, show_running=True, device=device,
        writer=writer
    )

writer.close()



Initial
----------
Valid |     5/5    | loss:    5.9090 | fps:  127.7812 | acc:    0.0000   

Epoch 1/30
----------
Train |   127/127  | loss:   19.6982 | fps:  148.9823 | acc:    0.0000   
Valid |     5/5    | loss:    4.9715 | fps:  187.7917 | acc:    0.0000   

Epoch 2/30
----------
Train |   127/127  | loss:   19.3140 | fps:  181.4872 | acc:    0.0000   
Valid |     5/5    | loss:    5.1828 | fps:  182.9131 | acc:    0.0000   

Epoch 3/30
----------
Train |   127/127  | loss:   19.2119 | fps:  174.8084 | acc:    0.0000   
Valid |     5/5    | loss:    5.3494 | fps:  187.3334 | acc:    0.0000   

Epoch 4/30
----------
Train |   127/127  | loss:   19.0874 | fps:  163.0974 | acc:    0.0000   
Valid |     5/5    | loss:    5.3084 | fps:  176.2360 | acc:    0.0000   

Epoch 5/30
----------
Train |   127/127  | loss:   19.0255 | fps:  169.6787 | acc:    0.0000   
Valid |     5/5    | loss:    5.4684 | fps:  187.1928 | acc:    0.0000   

Epoch 6/30
----------
Train |   127/127  | loss: 

In [60]:
torch.save(model.base.state_dict(), 'arcface.pth')

# Testing

In [34]:
def cosine_dist(src, test):
    a = np.multiply(src, test).sum(axis=1)
    b = np.multiply(src, src).sum(axis=1)
    c = np.multiply(test, test).sum(axis=1)
    return 1 - (a / (np.sqrt(b) * np.sqrt(c)))

## Verification

In [52]:
from PIL import Image

img1 = Image.open('/workspaces/mmsegmentation/face_recognition/face_aic/train/Tạ Văn Đại/0.jpg')
img2 = Image.open('/workspaces/mmsegmentation/face_recognition/face_aic/train/Tạ Văn Đại/1.jpg')
img3 = Image.open('/workspaces/mmsegmentation/face_recognition/face_aic/train/Bùi Ánh Hồng/0.jpg')

In [53]:
img1 = val_transform(img1)
img2 = val_transform(img2)
img3 = val_transform(img3)

In [54]:
img1.shape

torch.Size([3, 224, 224])

In [55]:
model.eval()
cur_embeddings = model.base(img1.unsqueeze(0).cuda()).cpu().detach().numpy()
comparing_embeddings = model.base(img2.unsqueeze(0).cuda()).cpu().detach().numpy()
dst = cosine_dist(cur_embeddings, comparing_embeddings)
print(dst)

[0.48646712]


In [None]:
model.eval()
cur_embeddings = model.base(img1.unsqueeze(0).cuda()).cpu().detach().numpy()
comparing_embeddings = model.base(img3.unsqueeze(0).cuda()).cpu().detach().numpy()
dst = cosine_dist(cur_embeddings, comparing_embeddings)
print(dst)

In [None]:
model.eval()
cur_embeddings = model.base(img2.unsqueeze(0).cuda()).cpu().detach().numpy()
comparing_embeddings = model.base(img3.unsqueeze(0).cuda()).cpu().detach().numpy()
dst = cosine_dist(cur_embeddings, comparing_embeddings)
print(dst)

## Identification

In [129]:
img1 = Image.open('/workspaces/mmsegmentation/face_recognition/face_aic/val/Phạm Ánh Nguyệt/0.jpg')
img1 = val_transform(img1)
cur_embeddings = model.base(img1.unsqueeze(0).cuda()).cpu().detach().numpy()

In [130]:
img_paths = glob.glob('/workspaces/mmsegmentation/face_recognition/face_aic/train/*/*.jpg')

In [131]:
dst_ls = []

for img in tqdm(img_paths):
    img2 = Image.open(img)
    img2 = val_transform(img2)
    comparing_embeddings = model.base(img2.unsqueeze(0).cuda()).cpu().detach().numpy()
    dst = cosine_dist(cur_embeddings, comparing_embeddings)
    dst_ls.append(dst)

100%|██████████| 509/509 [00:09<00:00, 53.47it/s]


In [132]:
len(dst_ls)

509

In [133]:
id = np.argmin(dst_ls)
id

61

In [134]:
img_paths[id]

'/workspaces/mmsegmentation/face_recognition/face_aic/train/Phạm Ánh Nguyệt/1.jpg'

In [135]:
dst_ls[id]

array([0.24180788], dtype=float32)