Для начала обучим нейросеть для распознавания лиц на датасете CelebA-500.

Для этого будем дообучать нейросеть ResNet50, обученную на датасете ImageNet.

In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary

import numpy as np
import math
import PIL
from sklearn.model_selection import train_test_split
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm, tqdm_notebook

from torchvision.models import resnet18, resnet50

In [2]:
# Работа была проведена на локальной машине, поэтому подключение Google Disk не требуется
# from google.colab import drive

# drive.mount('/content/drive')

In [3]:
image_size = 224

# Определим аугментацию для тренировки и теста модели
transform_train = transforms.Compose([
    transforms.Resize([image_size, image_size]),
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(0.5),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3),
    transforms.RandomGrayscale(p=0.3),
    # transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 3)),
    transforms.Normalize(
        (0.485, 0.456, 0.406),
        (0.229, 0.224, 0.225)),
])

transform_test = transforms.Compose([
    transforms.Resize([image_size, image_size]),
    transforms.ToTensor(),
    transforms.Normalize(
        (0.485, 0.456, 0.406),
        (0.229, 0.224, 0.225)),
])

path_to_anno = './celebA_train_500/celebA_anno.txt'
path_to_split = './celebA_train_500/celebA_train_split.txt'
path_to_imgs = './celebA_train_500/celebA_imgs/'

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
from collections import defaultdict

# file with query part annotations: which image belongs to which class
# format:
#     image_name_1.jpg 2678
#     image_name_2.jpg 2679

f = open(path_to_anno, 'r')
img_lines = f.readlines()
f.close()
img_lines = [x.strip('\n').split(' ') for x in img_lines]

f = open(path_to_split, 'r')
img_train_split = f.readlines()
f.close()
img_train_split = [x.strip('\n').split(' ') for x in img_train_split]

# plain list of image names from query. Needed to compute embeddings for query
img_names = [x[0] for x in img_lines]
img_labels = [int(x[1]) for x in img_lines]
img_split = [int(x[1]) for x in img_train_split]

for idx in range(len(img_lines)):
    img_lines[idx].append(img_split[idx])


In [5]:
# Определим тренировачные, валидационные и тестовые данные из файла celebA_anno.txt

x_train, x_val, x_test, y_train, y_val, y_test = [], [], [], [], [], []

np.random.shuffle(img_lines)
for line in img_lines:
    if line[2] == 0:
        x_train.append(line[0])
        y_train.append(int(line[1]))
    elif line[2] == 1:
        x_val.append(line[0])
        y_val.append(int(line[1]))
    elif line[2] == 2:
        x_test.append(line[0])
        y_test.append(int(line[1]))

In [6]:
print(len(x_train), len(x_val), len(x_test))

8544 1878 1589


In [7]:
class FacesDataset(Dataset):
    """
    Датасет с картинками, который паралельно подгружает их из папок
    производит скалирование и превращение в торчевые тензоры
    """
    def __init__(self, files, labels, mode, transform=transform_train):
        super().__init__()
        # список файлов для загрузки
        self.files = files

        self.len_ = len(self.files)

        self.labels = torch.LongTensor(labels)

        if mode == 'train':
            self.transform = transform_train
        else:
            self.transform = transform_test

    def __len__(self):
        return self.len_

    def __getitem__(self, index):
        # load image
        cur_path = path_to_imgs + self.files[index]
        cur_img = PIL.Image.open(cur_path).convert('RGB')
        cur_img = self.transform(cur_img)

        return cur_img, self.labels[index]

    def show_img(self, index):
        # load image
        cur_path = path_to_imgs + self.files[index]
        cur_img = PIL.Image.open(cur_path).convert('RGB')
        cur_img = self.transform(cur_img)
        cur_img.show()

In [8]:
# Будем размораживать слои на 13 эпохе
def adjust_freezing(model, epoch):
    layers_to_unfreeze = [model.layer3, model.layer4, model.avgpool]
    if epoch == 13:
        for layer in layers_to_unfreeze:
            for param in layer.parameters():
                param.requires_grad = True
        summary(model, input_size=(3, 224, 224))

def fit_epoch(model, train_loader, criterion, optimizer, scheduler):
    model.train()
    running_loss = 0.0
    running_corrects = 0
    processed_data = 0

    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        preds = torch.argmax(outputs, 1)
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)
        processed_data += inputs.size(0)

    scheduler.step()
    train_loss = running_loss / processed_data
    train_acc = running_corrects.cpu().numpy() / processed_data
    return train_loss, train_acc

def eval_epoch(model, val_loader, criterion):
    model.eval()
    running_loss = 0.0
    running_corrects = 0
    processed_size = 0

    for inputs, labels in val_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        with torch.set_grad_enabled(False):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            preds = torch.argmax(outputs, 1)

        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)
        processed_size += inputs.size(0)
    val_loss = running_loss / processed_size
    val_acc = running_corrects.double() / processed_size
    return val_loss, val_acc

def train(train_dataset, val_dataset, model, epochs, batch_size, optimizer, scheduler):
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    history = []
    log_template = "\nEpoch {ep:03d} train_loss: {t_loss:0.4f} \
    val_loss {v_loss:0.4f} train_acc {t_acc:0.4f} val_acc {v_acc:0.4f}"

    with tqdm(desc="epoch", total=epochs) as pbar_outer:
        criterion = nn.CrossEntropyLoss()

        for epoch in range(epochs):
            adjust_freezing(model, epoch)
            train_loss, train_acc = fit_epoch(model, train_loader, criterion, optimizer, scheduler)

            val_loss, val_acc = eval_epoch(model, val_loader, criterion)
            history.append((train_loss, train_acc, val_loss, val_acc))

            pbar_outer.update(1)
            tqdm.write(log_template.format(ep=epoch+1, t_loss=train_loss,\
                                           v_loss=val_loss, t_acc=train_acc, v_acc=val_acc))

    return history

In [9]:
def predict(model, test_loader):
    with torch.no_grad():
        logits = []

        for inputs in test_loader:
            inputs = inputs.to(device)
            model.eval()
            outputs = model(inputs).cpu()
            logits.append(outputs)

    probs = nn.functional.softmax(torch.cat(logits), dim=-1).numpy()
    return probs

In [10]:
def test_model(model, test_dataset):
    model.eval()
    running_corrects = 0
    processed_size = 0
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        with torch.set_grad_enabled(False):
            outputs = model(inputs)
            preds = torch.argmax(outputs, 1)

        running_corrects += torch.sum(preds == labels.data)
        processed_size += inputs.size(0)
    test_acc = running_corrects.cpu().numpy() / processed_size
    return test_acc

In [11]:
# Дообучим модель resnet50
n_classes = len(np.unique(img_labels))

model = resnet50(pretrained=True)
fc_in_features = model.fc.in_features
model.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.ReLU(),
    nn.Linear(fc_in_features, n_classes)
)

model.to(device)



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [12]:
# Заморозим первые 9 слоев сети, т.к. они выделяют низкоуровневые признаки, обучение с нуля которых
# может ухудшить качество наших предсказаний
counter = 0

for child in model.children():
    print(child)
    counter += 1
    if counter < 9:
        for param in child.parameters():
            param.requires_grad = False
print(counter)

Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
ReLU(inplace=True)
MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
Sequential(
  (0): Bottleneck(
    (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (downsample): Sequential(
      (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, t

In [13]:
#  Посмотрим на слои модели и удостоверимся, что первые 9 из них заморожены
model.to(device)
summary(model, input_size=(3, 224, 224))
print(1)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]           4,096
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]          16,384
      BatchNorm2d-12          [-1, 256, 56, 56]             512
           Conv2d-13          [-1, 256, 56, 56]          16,384
      BatchNorm2d-14          [-1, 256,

In [14]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, 1, 0.1, 10)

train_dataset = FacesDataset(x_train, y_train, mode='train')
val_dataset = FacesDataset(x_val, y_val, mode='val')
test_dataset = FacesDataset(x_test, y_test, mode='test')

In [15]:
epochs = 40
batch_size = 128

history = train(train_dataset, val_dataset, model, epochs, batch_size, optimizer, scheduler)

epoch:   2%|█▉                                                                          | 1/40 [01:06<43:19, 66.66s/it]


Epoch 001 train_loss: 6.4214     val_loss 5.7288 train_acc 0.0085 val_acc 0.0229


epoch:   5%|███▊                                                                        | 2/40 [02:15<42:53, 67.71s/it]


Epoch 002 train_loss: 5.4574     val_loss 5.2632 train_acc 0.0545 val_acc 0.0836


epoch:   8%|█████▋                                                                      | 3/40 [03:24<42:06, 68.28s/it]


Epoch 003 train_loss: 4.9241     val_loss 4.9745 train_acc 0.1077 val_acc 0.1145


epoch:  10%|███████▌                                                                    | 4/40 [04:33<41:17, 68.82s/it]


Epoch 004 train_loss: 4.5388     val_loss 4.7540 train_acc 0.1633 val_acc 0.1283


epoch:  12%|█████████▌                                                                  | 5/40 [05:43<40:25, 69.31s/it]


Epoch 005 train_loss: 4.2559     val_loss 4.6072 train_acc 0.2015 val_acc 0.1523


epoch:  15%|███████████▍                                                                | 6/40 [06:53<39:22, 69.49s/it]


Epoch 006 train_loss: 4.0291     val_loss 4.4635 train_acc 0.2423 val_acc 0.1778


epoch:  18%|█████████████▎                                                              | 7/40 [08:03<38:18, 69.65s/it]


Epoch 007 train_loss: 3.8737     val_loss 4.3644 train_acc 0.2659 val_acc 0.1991


epoch:  20%|███████████████▏                                                            | 8/40 [09:13<37:14, 69.84s/it]


Epoch 008 train_loss: 3.7268     val_loss 4.3003 train_acc 0.2930 val_acc 0.2050


epoch:  22%|█████████████████                                                           | 9/40 [10:23<36:02, 69.76s/it]


Epoch 009 train_loss: 3.6138     val_loss 4.2327 train_acc 0.3198 val_acc 0.2178


epoch:  25%|██████████████████▊                                                        | 10/40 [11:29<34:16, 68.56s/it]


Epoch 010 train_loss: 3.5203     val_loss 4.1848 train_acc 0.3384 val_acc 0.2274


epoch:  28%|████████████████████▋                                                      | 11/40 [12:35<32:43, 67.71s/it]


Epoch 011 train_loss: 3.4614     val_loss 4.1640 train_acc 0.3621 val_acc 0.2332


epoch:  30%|██████████████████████▌                                                    | 12/40 [13:42<31:32, 67.59s/it]


Epoch 012 train_loss: 3.4482     val_loss 4.1442 train_acc 0.3600 val_acc 0.2348


epoch:  32%|████████████████████████▍                                                  | 13/40 [14:48<30:07, 66.96s/it]


Epoch 013 train_loss: 3.4175     val_loss 4.1364 train_acc 0.3663 val_acc 0.2370
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]           4,096
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]          16,384
      BatchNorm2d-12          [-1, 256, 56, 56]             512
           Conv2d-13 

epoch:  35%|██████████████████████████▎                                                | 14/40 [16:04<30:15, 69.83s/it]


Epoch 014 train_loss: 2.6938     val_loss 2.8365 train_acc 0.4567 val_acc 0.4116


epoch:  38%|████████████████████████████▏                                              | 15/40 [17:21<30:01, 72.06s/it]


Epoch 015 train_loss: 1.8977     val_loss 2.4362 train_acc 0.6251 val_acc 0.4840


epoch:  40%|██████████████████████████████                                             | 16/40 [18:39<29:27, 73.63s/it]


Epoch 016 train_loss: 1.4141     val_loss 2.2506 train_acc 0.7463 val_acc 0.5463


epoch:  42%|███████████████████████████████▉                                           | 17/40 [19:56<28:37, 74.70s/it]


Epoch 017 train_loss: 1.0570     val_loss 2.0459 train_acc 0.8271 val_acc 0.5719


epoch:  45%|█████████████████████████████████▊                                         | 18/40 [21:12<27:35, 75.23s/it]


Epoch 018 train_loss: 0.7647     val_loss 1.9439 train_acc 0.8984 val_acc 0.6017


epoch:  48%|███████████████████████████████████▋                                       | 19/40 [22:29<26:28, 75.63s/it]


Epoch 019 train_loss: 0.5616     val_loss 1.8537 train_acc 0.9410 val_acc 0.6129


epoch:  50%|█████████████████████████████████████▌                                     | 20/40 [23:45<25:17, 75.87s/it]


Epoch 020 train_loss: 0.3948     val_loss 1.7353 train_acc 0.9663 val_acc 0.6512


epoch:  52%|███████████████████████████████████████▍                                   | 21/40 [25:02<24:04, 76.03s/it]


Epoch 021 train_loss: 0.2908     val_loss 1.6607 train_acc 0.9793 val_acc 0.6571


epoch:  55%|█████████████████████████████████████████▎                                 | 22/40 [26:18<22:52, 76.23s/it]


Epoch 022 train_loss: 0.2048     val_loss 1.6098 train_acc 0.9906 val_acc 0.6640


epoch:  57%|███████████████████████████████████████████▏                               | 23/40 [27:35<21:37, 76.33s/it]


Epoch 023 train_loss: 0.1511     val_loss 1.5739 train_acc 0.9951 val_acc 0.6613


epoch:  60%|█████████████████████████████████████████████                              | 24/40 [28:51<20:21, 76.35s/it]


Epoch 024 train_loss: 0.1158     val_loss 1.5263 train_acc 0.9970 val_acc 0.6709


epoch:  62%|██████████████████████████████████████████████▉                            | 25/40 [30:08<19:05, 76.39s/it]


Epoch 025 train_loss: 0.0942     val_loss 1.5170 train_acc 0.9974 val_acc 0.6853


epoch:  65%|████████████████████████████████████████████████▊                          | 26/40 [31:24<17:50, 76.44s/it]


Epoch 026 train_loss: 0.0747     val_loss 1.4883 train_acc 0.9982 val_acc 0.6784


epoch:  68%|██████████████████████████████████████████████████▋                        | 27/40 [32:41<16:36, 76.62s/it]


Epoch 027 train_loss: 0.0589     val_loss 1.4368 train_acc 0.9998 val_acc 0.6912


epoch:  70%|████████████████████████████████████████████████████▌                      | 28/40 [33:59<15:21, 76.82s/it]


Epoch 028 train_loss: 0.0465     val_loss 1.4186 train_acc 0.9996 val_acc 0.6944


epoch:  72%|██████████████████████████████████████████████████████▍                    | 29/40 [35:19<14:16, 77.85s/it]


Epoch 029 train_loss: 0.0429     val_loss 1.4124 train_acc 0.9991 val_acc 0.7039


epoch:  75%|████████████████████████████████████████████████████████▎                  | 30/40 [36:40<13:08, 78.83s/it]


Epoch 030 train_loss: 0.0370     val_loss 1.3841 train_acc 0.9996 val_acc 0.7023


epoch:  78%|██████████████████████████████████████████████████████████▏                | 31/40 [38:01<11:56, 79.62s/it]


Epoch 031 train_loss: 0.0323     val_loss 1.3850 train_acc 0.9995 val_acc 0.7103


epoch:  80%|████████████████████████████████████████████████████████████               | 32/40 [39:23<10:41, 80.18s/it]


Epoch 032 train_loss: 0.0261     val_loss 1.3622 train_acc 0.9998 val_acc 0.7039


epoch:  82%|█████████████████████████████████████████████████████████████▉             | 33/40 [40:42<09:19, 79.91s/it]


Epoch 033 train_loss: 0.0241     val_loss 1.3390 train_acc 0.9998 val_acc 0.7135


epoch:  85%|███████████████████████████████████████████████████████████████▊           | 34/40 [41:58<07:52, 78.82s/it]


Epoch 034 train_loss: 0.0203     val_loss 1.3542 train_acc 0.9996 val_acc 0.7050


epoch:  88%|█████████████████████████████████████████████████████████████████▋         | 35/40 [43:15<06:29, 77.99s/it]


Epoch 035 train_loss: 0.0186     val_loss 1.3214 train_acc 0.9998 val_acc 0.7125


epoch:  90%|███████████████████████████████████████████████████████████████████▌       | 36/40 [44:31<05:09, 77.42s/it]


Epoch 036 train_loss: 0.0168     val_loss 1.3269 train_acc 0.9999 val_acc 0.7188


epoch:  92%|█████████████████████████████████████████████████████████████████████▍     | 37/40 [45:48<03:51, 77.28s/it]


Epoch 037 train_loss: 0.0144     val_loss 1.3096 train_acc 1.0000 val_acc 0.7173


epoch:  95%|███████████████████████████████████████████████████████████████████████▎   | 38/40 [47:05<02:34, 77.36s/it]


Epoch 038 train_loss: 0.0132     val_loss 1.2934 train_acc 0.9999 val_acc 0.7226


epoch:  98%|█████████████████████████████████████████████████████████████████████████▏ | 39/40 [48:21<01:17, 77.05s/it]


Epoch 039 train_loss: 0.0119     val_loss 1.3028 train_acc 1.0000 val_acc 0.7231


epoch: 100%|███████████████████████████████████████████████████████████████████████████| 40/40 [49:39<00:00, 74.48s/it]


Epoch 040 train_loss: 0.0113     val_loss 1.2925 train_acc 0.9996 val_acc 0.7284





In [16]:
# Сохраняем модель при необходимости
# torch.save(model.state_dict(), 'model_weights.pth')

# Очищаем память от посчитанных градиентов
torch.cuda.empty_cache()

In [15]:
# Загружаем модель при необходимости
# model.load_state_dict(torch.load('model_weights.pth', weights_only=True))

<All keys matched successfully>

In [17]:
# Тестируем модель
test_acc = test_model(model, test_dataset)
print(f'Test Accuracy: {test_acc:.2f}')

Test Accuracy: 0.71


In [18]:
# Если вы работаете с данными, которые даны по ссылке,
# то эта ячейка поможет их загрузить
from collections import defaultdict
import os 
# file with query part annotations: which image belongs to which class
# format:
#     image_name_1.jpg 2678
#     image_name_2.jpg 2679
f = open('./celebA_ir/celebA_anno_query.csv', 'r')
query_lines = f.readlines()[1:]
f.close()
query_lines = [x.strip().split(',') for x in query_lines]
# plain list of image names from query. Neede to compute embeddings for query
query_img_names = [x[0] for x in query_lines]

# dictionary with info of which images from query belong to which class
# format:
#     {class: [image_1, image_2, ...]}
query_dict = defaultdict(list)
for img_name, img_class in query_lines:
  query_dict[img_class].append(img_name)

# list of distractor images
distractors_img_names = os.listdir('./celebA_ir/celebA_distractors')

In [19]:
def compute_embeddings(model, images_list):
    '''
    compute embeddings from the trained model for list of images.
    params:
    model: trained nn model that takes images and outputs embeddings
    images_list: list of images paths to compute embeddings for
    output:
    list: list of model embeddings. Each embedding corresponds to images
          names from images_list
    '''
    embeddings = []
    model.eval()
    with torch.no_grad():
        for idx in range(len(images_list)):
            img_name = images_list[idx]
            if images_list == query_img_names:
                image = PIL.Image.open('celebA_ir/celebA_query/'+img_name).convert('RGB')
            elif images_list == distractors_img_names:
                image = PIL.Image.open('celebA_ir/celebA_distractors/'+img_name).convert('RGB')
            else:
                raise Exception('Invalid list').with_traceback(traceback_obj)

            image = transform_test(image)
            image = image.to(device)
            output = model(image.unsqueeze(0))
            embeddings.append(output.detach().cpu())
    
    return torch.cat(embeddings).numpy()

In [20]:
query_embeddings = compute_embeddings(model, query_img_names)
distractors_embeddings = compute_embeddings(model, distractors_img_names)

In [21]:
def compute_cosine_query_pos(query_dict, query_img_names, query_embeddings):
    '''
    compute cosine similarities between positive pairs from query (stage 1)
    params:
    query_dict: dict {class: [image_name_1, image_name_2, ...]}. Key: class in
                the dataset. Value: images corresponding to that class
    query_img_names: list of images names
    query_embeddings: list of embeddings corresponding to query_img_names
    output:
    list of floats: similarities between embeddings corresponding
                    to the same people from query list
    '''
    
    img_to_emb = dict(zip(query_img_names, query_embeddings))
    result = []
    for idx, values in query_dict.items():
        for i in range(len(values)-1):
            for j in range(i+1, len(values)):
                cos_sim = cosine_similarity(np.array(img_to_emb[values[i]]).reshape(1,-1), np.array(img_to_emb[values[j]]).reshape(1,-1))
                result.append(cos_sim[0][0])
    return np.array(result)

def compute_cosine_query_neg(query_dict, query_img_names, query_embeddings):
    '''
    compute cosine similarities between negative pairs from query (stage 2)
    params:
    query_dict: dict {class: [image_name_1, image_name_2, ...]}. Key: class in
                the dataset. Value: images corresponding to that class
    query_img_names: list of images names
    query_embeddings: list of embeddings corresponding to query_img_names
    output:
    list of floats: similarities between embeddings corresponding
                    to different people from query list
    '''
    img_to_emb = dict(zip(query_img_names, query_embeddings))
    result = []
    keys = list(query_dict.keys())
    for i in range(len(keys)-1):
        values = query_dict[keys[i]]
        temp = []
        for j in range(i+1, len(keys)):
            for img in query_dict[keys[j]]:
                temp.append(img)
        for j in range(len(values)):
            for k in range(len(temp)):
                cos_sim = cosine_similarity(np.array(img_to_emb[values[j]]).reshape(1,-1), np.array(img_to_emb[temp[k]]).reshape(1,-1))
                result.append(cos_sim[0][0])
            
    return np.array(result)

def compute_cosine_query_distractors(query_embeddings, distractors_embeddings):
    '''
    compute cosine similarities between negative pairs from query and distractors
    (stage 3)
    params:
    query_embeddings: list of embeddings corresponding to query_img_names
    distractors_embeddings: list of embeddings corresponding to distractors_img_names
    output:
    list of floats: similarities between pairs of people (q, d), where q is
                    embedding corresponding to photo from query, d —
                    embedding corresponding to photo from distractors
    '''
    return np.array(cosine_similarity(query_embeddings, distractors_embeddings)).flatten()

In [22]:
cosine_query_pos = compute_cosine_query_pos(query_dict, query_img_names,
                                            query_embeddings)
cosine_query_neg = compute_cosine_query_neg(query_dict, query_img_names,
                                            query_embeddings)
cosine_query_distractors = compute_cosine_query_distractors(query_embeddings,
                                                            distractors_embeddings)


Ячейка ниже проверяет, что код работает верно:

In [23]:
test_query_dict = {
    2876: ['1.jpg', '2.jpg', '3.jpg'],
    5674: ['5.jpg'],
    864:  ['9.jpg', '10.jpg'],
}
test_query_img_names = ['1.jpg', '2.jpg', '3.jpg', '5.jpg', '9.jpg', '10.jpg']
test_query_embeddings = [
                    [1.56, 6.45,  -7.68],
                    [-1.1 , 6.11,  -3.0],
                    [-0.06,-0.98,-1.29],
                    [8.56, 1.45,  1.11],
                    [0.7,  1.1,   -7.56],
                    [0.05, 0.9,   -2.56],
]

test_distractors_img_names = ['11.jpg', '12.jpg', '13.jpg', '14.jpg', '15.jpg']

test_distractors_embeddings = [
                    [0.12, -3.23, -5.55],
                    [-1,   -0.01, 1.22],
                    [0.06, -0.23, 1.34],
                    [-6.6, 1.45,  -1.45],
                    [0.89,  1.98, 1.45],
]

test_cosine_query_pos = compute_cosine_query_pos(test_query_dict, test_query_img_names,
                                            test_query_embeddings)
test_cosine_query_neg = compute_cosine_query_neg(test_query_dict, test_query_img_names,
                                            test_query_embeddings)
test_cosine_query_distractors = compute_cosine_query_distractors(test_query_embeddings,
                                                            test_distractors_embeddings)

In [24]:
true_cosine_query_pos = [0.8678237233650096, 0.21226104378511604,
                         -0.18355866977496182, 0.9787437979250561]
assert np.allclose(sorted(test_cosine_query_pos), sorted(true_cosine_query_pos)), \
      "A mistake in compute_cosine_query_pos function"

true_cosine_query_neg = [0.15963231223161822, 0.8507997093616965, 0.9272761484302097,
                         -0.0643994061127092, 0.5412660901220571, 0.701307100338029,
                         -0.2372575528216902, 0.6941032794522218, 0.549425446066643,
                         -0.011982733001947084, -0.0466679194884999]
assert np.allclose(sorted(test_cosine_query_neg), sorted(true_cosine_query_neg)), \
      "A mistake in compute_cosine_query_neg function"

true_cosine_query_distractors = [0.3371426578637511, -0.6866465610863652, -0.8456563512871669,
                                 0.14530087113136106, 0.11410510307646118, -0.07265097629002357,
                                 -0.24097699660707042,-0.5851992679925766, 0.4295494455718534,
                                 0.37604478596058194, 0.9909483738948858, -0.5881093317868022,
                                 -0.6829712976642919, 0.07546364489032083, -0.9130970963915521,
                                 -0.17463101988684684, -0.5229363015558941, 0.1399896725311533,
                                 -0.9258034013399499, 0.5295114163723346, 0.7811585442749943,
                                 -0.8208760031249596, -0.9905139680301821, 0.14969764653247228,
                                 -0.40749654525418444, 0.648660814944824, -0.7432584300096284,
                                 -0.9839696492435877, 0.2498741082804709, -0.2661183373780491]
assert np.allclose(sorted(test_cosine_query_distractors), sorted(true_cosine_query_distractors)), \
      "A mistake in compute_cosine_query_distractors function"

И, наконец, финальная функция, которая считает IR metric:

In [25]:
def compute_ir(cosine_query_pos, cosine_query_neg, cosine_query_distractors,
               fpr=0.1):
    '''
    compute identification rate using precomputer cosine similarities between pairs
    at given fpr
    params:
    cosine_query_pos: cosine similarities between positive pairs from query
    cosine_query_neg: cosine similarities between negative pairs from query
    cosine_query_distractors: cosine similarities between negative pairs
                              from query and distractors
    fpr: false positive rate at which to compute TPR
    output:
    float: threshold for given fpr
    float: TPR at given FPR
    '''
    N = int(fpr * (len(cosine_query_neg) + len(cosine_query_distractors)))
    false_pairs = np.concatenate([cosine_query_neg, cosine_query_distractors])
    false_pair = np.sort(false_pairs)[::-1]
    threshold = false_pair[N] 
    cosine_query_pos = np.array(cosine_query_pos)
    TPR = len(cosine_query_pos[cosine_query_pos >= threshold]) / len(cosine_query_pos) 
    return threshold, TPR

И ячейки для ее проверки:

In [26]:
test_thr = []
test_tpr = []
for fpr in [0.5, 0.3, 0.1]:
  x, y = compute_ir(test_cosine_query_pos, test_cosine_query_neg,
                    test_cosine_query_distractors, fpr=fpr)
  test_thr.append(x)
  test_tpr.append(y)

In [27]:
true_thr = [-0.011982733001947084, 0.3371426578637511, 0.701307100338029]
assert np.allclose(np.array(test_thr), np.array(true_thr)), "A mistake in computing threshold"

true_tpr = [0.75, 0.5, 0.5]
assert np.allclose(np.array(test_tpr), np.array(true_tpr)), "A mistake in computing tpr"

А в ячейке ниже вы можете посчитать TPR@FPR для датасета с лицами. Давайте, например, посчитаем для значений fpr = [0.5, 0.2, 0.1, 0.05].

In [28]:
fprs = [0.5, 0.2, 0.1, 0.05]

for fpr in fprs:
    x, y = compute_ir(cosine_query_pos, cosine_query_neg, cosine_query_distractors, 
               fpr)
    print(f'threshold for {fpr}: {y:.2f}')


threshold for 0.5: 0.86
threshold for 0.2: 0.59
threshold for 0.1: 0.42
threshold for 0.05: 0.29


In [29]:
# Очищаем память
torch.cuda.empty_cache()

In [30]:
class ArcMarginProduct(nn.Module):
    '''
    Implement of large margin arc distance: :
    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        s: norm of input feature
        m: margin
        cos(theta + m)
    '''
    def __init__(self, in_features, out_features, s=30.0, m=2, easy_margin=False, ls_eps=0.0):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.ls_eps = ls_eps  # label smoothing
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m
        
    def forward(self, input, label):
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------------
        one_hot = torch.zeros(cosine.size(), device='cuda')
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.out_features
        # -------------torch.where(out_i = {x_i if condition_i else y_i) -------------
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s

        return output

In [31]:
# Дообучим модель на основе модели ResNet50
model_afl = resnet50(pretrained=True)

# Заменим fc слой, а также добавим слой ArcMarginProduct
fc_in_features = model_afl.fc.in_features
model_afl.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.ReLU(),
    nn.Linear(fc_in_features, n_classes)
)
model_afl.fc1 = ArcMarginProduct(n_classes, n_classes)

model_afl.to(device)

# Заморозим первые 9 слоев сети, т.к. они выделяют низкоуровневые признаки, обучение с нуля которых
# может ухудшить качество наших предсказаний
counter = 0

for child in model_afl.children():
    print(child)
    counter += 1
    if counter < 9:
        for param in child.parameters():
            param.requires_grad = False
print(counter)



Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
ReLU(inplace=True)
MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
Sequential(
  (0): Bottleneck(
    (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (downsample): Sequential(
      (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, t

In [32]:
# Обучим модель с изменямым LR на первых 10 эпохах
optimizer = torch.optim.Adam(model_afl.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, 1, 0.1, 10)

epochs = 40
batch_size = 128

history = train(train_dataset, val_dataset, model_afl, epochs, batch_size, optimizer, scheduler)

epoch:   2%|█▉                                                                          | 1/40 [01:05<42:19, 65.12s/it]


Epoch 001 train_loss: 6.4148     val_loss 5.7432 train_acc 0.0095 val_acc 0.0261


epoch:   5%|███▊                                                                        | 2/40 [02:10<41:18, 65.21s/it]


Epoch 002 train_loss: 5.4475     val_loss 5.2844 train_acc 0.0520 val_acc 0.0724


epoch:   8%|█████▋                                                                      | 3/40 [03:16<40:24, 65.52s/it]


Epoch 003 train_loss: 4.9158     val_loss 4.9987 train_acc 0.1053 val_acc 0.1113


epoch:  10%|███████▌                                                                    | 4/40 [04:22<39:26, 65.74s/it]


Epoch 004 train_loss: 4.5545     val_loss 4.7435 train_acc 0.1511 val_acc 0.1353


epoch:  12%|█████████▌                                                                  | 5/40 [05:28<38:24, 65.84s/it]


Epoch 005 train_loss: 4.2605     val_loss 4.6178 train_acc 0.2060 val_acc 0.1576


epoch:  15%|███████████▍                                                                | 6/40 [06:34<37:26, 66.07s/it]


Epoch 006 train_loss: 4.0454     val_loss 4.4771 train_acc 0.2367 val_acc 0.1858


epoch:  18%|█████████████▎                                                              | 7/40 [07:41<36:26, 66.26s/it]


Epoch 007 train_loss: 3.8676     val_loss 4.3894 train_acc 0.2656 val_acc 0.1981


epoch:  20%|███████████████▏                                                            | 8/40 [08:47<35:19, 66.24s/it]


Epoch 008 train_loss: 3.7061     val_loss 4.2922 train_acc 0.3043 val_acc 0.2114


epoch:  22%|█████████████████                                                           | 9/40 [09:54<34:14, 66.27s/it]


Epoch 009 train_loss: 3.6085     val_loss 4.2474 train_acc 0.3257 val_acc 0.2263


epoch:  25%|██████████████████▊                                                        | 10/40 [11:00<33:06, 66.20s/it]


Epoch 010 train_loss: 3.5219     val_loss 4.2000 train_acc 0.3413 val_acc 0.2380


epoch:  28%|████████████████████▋                                                      | 11/40 [12:06<32:03, 66.32s/it]


Epoch 011 train_loss: 3.4716     val_loss 4.1744 train_acc 0.3529 val_acc 0.2417


epoch:  30%|██████████████████████▌                                                    | 12/40 [13:13<30:57, 66.33s/it]


Epoch 012 train_loss: 3.4237     val_loss 4.1697 train_acc 0.3652 val_acc 0.2455


epoch:  32%|████████████████████████▍                                                  | 13/40 [14:18<29:44, 66.09s/it]


Epoch 013 train_loss: 3.4101     val_loss 4.1534 train_acc 0.3703 val_acc 0.2460
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]           4,096
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]          16,384
      BatchNorm2d-12          [-1, 256, 56, 56]             512
           Conv2d-13 

epoch:  35%|██████████████████████████▎                                                | 14/40 [15:36<30:08, 69.55s/it]


Epoch 014 train_loss: 2.6940     val_loss 2.7849 train_acc 0.4458 val_acc 0.4217


epoch:  38%|████████████████████████████▏                                              | 15/40 [16:53<29:59, 71.99s/it]


Epoch 015 train_loss: 1.8624     val_loss 2.5139 train_acc 0.6361 val_acc 0.4803


epoch:  40%|██████████████████████████████                                             | 16/40 [18:11<29:28, 73.68s/it]


Epoch 016 train_loss: 1.4002     val_loss 2.2147 train_acc 0.7433 val_acc 0.5399


epoch:  42%|███████████████████████████████▉                                           | 17/40 [19:29<28:44, 74.97s/it]


Epoch 017 train_loss: 1.0353     val_loss 2.0629 train_acc 0.8318 val_acc 0.5676


epoch:  45%|█████████████████████████████████▊                                         | 18/40 [20:47<27:47, 75.80s/it]


Epoch 018 train_loss: 0.7642     val_loss 1.9365 train_acc 0.8924 val_acc 0.5953


epoch:  48%|███████████████████████████████████▋                                       | 19/40 [22:04<26:42, 76.29s/it]


Epoch 019 train_loss: 0.5495     val_loss 1.8234 train_acc 0.9363 val_acc 0.6177


epoch:  50%|█████████████████████████████████████▌                                     | 20/40 [23:22<25:32, 76.64s/it]


Epoch 020 train_loss: 0.3975     val_loss 1.7168 train_acc 0.9656 val_acc 0.6443


epoch:  52%|███████████████████████████████████████▍                                   | 21/40 [24:39<24:19, 76.79s/it]


Epoch 021 train_loss: 0.2757     val_loss 1.6871 train_acc 0.9823 val_acc 0.6406


epoch:  55%|█████████████████████████████████████████▎                                 | 22/40 [25:56<23:05, 76.96s/it]


Epoch 022 train_loss: 0.2094     val_loss 1.6220 train_acc 0.9898 val_acc 0.6587


epoch:  57%|███████████████████████████████████████████▏                               | 23/40 [27:13<21:49, 77.01s/it]


Epoch 023 train_loss: 0.1453     val_loss 1.5725 train_acc 0.9950 val_acc 0.6693


epoch:  60%|█████████████████████████████████████████████                              | 24/40 [28:31<20:34, 77.18s/it]


Epoch 024 train_loss: 0.1126     val_loss 1.5196 train_acc 0.9975 val_acc 0.6800


epoch:  62%|██████████████████████████████████████████████▉                            | 25/40 [29:48<19:18, 77.26s/it]


Epoch 025 train_loss: 0.0869     val_loss 1.4817 train_acc 0.9982 val_acc 0.6869


epoch:  65%|████████████████████████████████████████████████▊                          | 26/40 [31:05<18:01, 77.27s/it]


Epoch 026 train_loss: 0.0707     val_loss 1.4616 train_acc 0.9985 val_acc 0.6880


epoch:  68%|██████████████████████████████████████████████████▋                        | 27/40 [32:22<16:42, 77.15s/it]


Epoch 027 train_loss: 0.0578     val_loss 1.4503 train_acc 0.9993 val_acc 0.6933


epoch:  70%|████████████████████████████████████████████████████▌                      | 28/40 [33:39<15:23, 77.00s/it]


Epoch 028 train_loss: 0.0477     val_loss 1.4084 train_acc 0.9993 val_acc 0.7023


epoch:  72%|██████████████████████████████████████████████████████▍                    | 29/40 [34:56<14:06, 76.95s/it]


Epoch 029 train_loss: 0.0409     val_loss 1.3815 train_acc 0.9993 val_acc 0.7125


epoch:  75%|████████████████████████████████████████████████████████▎                  | 30/40 [36:13<12:49, 76.94s/it]


Epoch 030 train_loss: 0.0366     val_loss 1.3921 train_acc 1.0000 val_acc 0.7071


epoch:  78%|██████████████████████████████████████████████████████████▏                | 31/40 [37:29<11:31, 76.82s/it]


Epoch 031 train_loss: 0.0304     val_loss 1.3682 train_acc 0.9999 val_acc 0.7103


epoch:  80%|████████████████████████████████████████████████████████████               | 32/40 [38:46<10:13, 76.73s/it]


Epoch 032 train_loss: 0.0252     val_loss 1.3413 train_acc 0.9999 val_acc 0.7130


epoch:  82%|█████████████████████████████████████████████████████████████▉             | 33/40 [40:02<08:56, 76.69s/it]


Epoch 033 train_loss: 0.0238     val_loss 1.3378 train_acc 0.9998 val_acc 0.7215


epoch:  85%|███████████████████████████████████████████████████████████████▊           | 34/40 [41:19<07:39, 76.60s/it]


Epoch 034 train_loss: 0.0202     val_loss 1.3316 train_acc 0.9998 val_acc 0.7157


epoch:  88%|█████████████████████████████████████████████████████████████████▋         | 35/40 [42:36<06:24, 76.90s/it]


Epoch 035 train_loss: 0.0185     val_loss 1.3199 train_acc 1.0000 val_acc 0.7194


epoch:  90%|███████████████████████████████████████████████████████████████████▌       | 36/40 [43:55<05:09, 77.28s/it]


Epoch 036 train_loss: 0.0169     val_loss 1.3237 train_acc 0.9999 val_acc 0.7173


epoch:  92%|█████████████████████████████████████████████████████████████████████▍     | 37/40 [45:13<03:52, 77.61s/it]


Epoch 037 train_loss: 0.0158     val_loss 1.3101 train_acc 0.9998 val_acc 0.7199


epoch:  95%|███████████████████████████████████████████████████████████████████████▎   | 38/40 [46:31<02:35, 77.63s/it]


Epoch 038 train_loss: 0.0137     val_loss 1.2990 train_acc 1.0000 val_acc 0.7226


epoch:  98%|█████████████████████████████████████████████████████████████████████████▏ | 39/40 [47:49<01:17, 77.80s/it]


Epoch 039 train_loss: 0.0126     val_loss 1.2959 train_acc 0.9999 val_acc 0.7295


epoch: 100%|███████████████████████████████████████████████████████████████████████████| 40/40 [49:06<00:00, 73.67s/it]


Epoch 040 train_loss: 0.0124     val_loss 1.2861 train_acc 0.9999 val_acc 0.7258





In [33]:
# Тестируем модель
test_acc = test_model(model_afl, test_dataset)
print(f'Test Accuracy: {test_acc:.2f}')

Test Accuracy: 0.72


In [34]:
# Сохраняем модель при необходимости
# torch.save(model_afl.state_dict(), 'model_weights_arcfaceloss.pth')