In [2]:
'''
Contains all modules for getting a APFA model.
'''
import torch
from torch import nn
import torchvision.models as models
import random
import torch.nn.functional as F


class Masked_Residual_Aggregation(nn.Module):
    def __init__(self, in_channels : int, num_parts : int):
        '''
            in_channels : a int number, which implies the size of the input tensor to this module.
            num_parts : a int number, which implies the count of hidden layers. They enhanced the layer's features
        '''
        super(Masked_Residual_Aggregation, self).__init__()
        self.in_channels = in_channels
        self.num_parts = num_parts

        # Convolutional 1x1 to get a m Layers(attantion layers)
        self.mask_conv = nn.Conv2d(in_channels, num_parts, kernel_size=1)

        # Convolutional 1x1 to obtain the final h(x) (H x W x (CxM) --> H x W x (C))
        self.final_mask_conv = nn.Conv2d(in_channels * num_parts, in_channels, kernel_size=1)

    def forward(self, x):
        masks = torch.sigmoid(self.mask_conv(x)) # --> H x W x m

        # Expand masks(N_layer) to corresponding size of f(x)
        masks = masks.unsqueeze(2)   # ---> H x W x 1 x m

        # B x C x H x W
        x_exp = x.unsqueeze(1)
        attention_layers = x_exp * masks
        h = attention_layers + x_exp    # --> H x W x C x M

        h = torch.cat([h[:, i, :, :, :] for i in range(self.num_parts)], dim=1)

        h_1 = self.final_mask_conv(h)

        h_output = torch.fft.fft2(h_1, norm='ortho')

        return h_output


class Phase_Based_Augmentation(nn.Module):
    def __init__(self, gamma=0.1):
        '''
            gamma : a int number, which explains adding residual to the phase spectrum of other images in batch.
        '''
        super().__init__()
        self.gamma = gamma

    def forward(self, h_freq):
        batch_size, channels, H, W = h_freq.shape

        amplitude = torch.abs(h_freq)
        phase_origin = torch.angle(h_freq)

        h_ran = h_freq[random.randint(0, batch_size-1)] # [channels, H, W]

        phase_random = torch.angle(h_ran)

        phase_new = self.gamma * phase_random.unsqueeze(0) + (1 - self.gamma) * phase_origin

        h_freq_new = amplitude * torch.exp(1j * phase_new)

        return h_freq_new


class Hybrid_Module(nn.Module):
    # in_channels = кол-во каналов из batch из CNN(HxWxC)
    # num_parts кол-во масок в Masked Residual Aggregation, гипер параматр
    # gamma - для фазовой аугментации
    def __init__(self, in_channels, num_parts, gamma, ifft=True):
        super().__init__()
        self.masked_module = Masked_Residual_Aggregation(in_channels, num_parts)
        self.phase_augm = Phase_Based_Augmentation(gamma)
        self.ifft = ifft

    def forward(self, x):
        x = self.masked_module(x)

        x = self.phase_augm(x)

        if self.ifft:
            x = torch.fft.ifft2(x, norm='ortho').real

        return x


class APFA(nn.Module):
    def __init__(self, in_channels=2048, num_parts=4, gamma=0.1, num_classes=1000):
        super().__init__()

        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        self.resnet_encoder = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            resnet.layer4
        )

        self.hyb_module = Hybrid_Module(in_channels, num_parts, gamma)

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.1),
            nn.Linear(in_channels, num_classes)
        )

    def forward(self, x):
        x = self.resnet_encoder(x)  # [B, 2048, H, W]

        x = self.hyb_module(x)

        x = self.pool(x)
        x = x.flatten(1)
        x = self.classifier(x)

        return x


class Triplet_Network(nn.Module):
    def __init__(self, original_model_parameters):
        super().__init__()

        self.orig_branch = APFA(**original_model_parameters)

    def forward(self, orig_img):
        norm_emb = F.normalize(self.orig_branch(orig_img), p=2, dim=1)

        return norm_emb


In [3]:
from torchvision.transforms import transforms
import torch
from collections import defaultdict
import glob
import os
from PIL import Image
import random
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from sklearn.neighbors import KNeighborsClassifier
import numpy as np


def coll_fn_augm(batch):
    batch_size = len(batch)
    images, labels = zip(*batch)
    batch = torch.tensor(images)

    alpha = 0.1

    h_freq = torch.fft.fftshift(torch.fft.fft2(batch, norm='ortho'))
    amplitude = torch.abs(h_freq)

    phase_origin = torch.angle(h_freq)

    h_ran = h_freq[random.randint(0, batch_size-1)] # [channels, H, W]
    phase_random = torch.angle(h_ran)
    phase_new = alpha * phase_random.unsqueeze(0) + (1 - alpha) * phase_origin
    h_freq_new = amplitude * torch.exp(1j * phase_new)

    output = torch.fft.ifft2(torch.fft.ifftshift(h_freq_new), norm='ortho').real

    return output


#(64, 512, 3)
norm_transform = transforms.Compose([
    transforms.Resize((64, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

another_transform = transforms.Compose([
    # Геометрические преобразования (применяются к исходному изображению до изменения размера)
    transforms.RandomApply([
        transforms.RandomRotation(10),  # Случайный поворот ±10 градусов
        transforms.RandomPerspective(  # Перспективные искажения
            distortion_scale=0.15,
            p=0.3
        ),
    ], p=0.5),

    # Цветовые преобразования (важно для разных условий освещения)
    transforms.ColorJitter(
        brightness=0.15,  # Яркость
        contrast=0.15,    # Контраст
        saturation=0.1,   # Насыщенность
        hue=0.05          # Оттенок (малое значение для сохранения цветов радужки)
    ),

    # Размытия и шумы
    transforms.RandomApply([
        transforms.GaussianBlur(  # Гауссово размытие
            kernel_size=3,
            sigma=(0.1, 1.5))
    ], p=0.3),

    transforms.Resize((224, 224)),
    transforms.ToTensor(),

    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])


# TODO
# написать датасет для разбиения обучающей и тестовой выборки
class Iris_Classification_Dataset(torch.utils.data.Dataset):
    def __init__():
        pass

    def __len__():
        pass

    def __getitem__():
        pass



class IrisDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        root=None,
        num_seen_classes=20,
        transform=None,
        mode=None,
        list_files=None
    ):
        self.root = root
        self.transform = transform
        self.mode = mode

        self.class_to_idxs = defaultdict(list)
        self.data = []

        all_patients = sorted(
            glob.glob(os.path.join(self.root, '*')),
            key=lambda x: int(os.path.basename(x))
        )

        if mode == "train":
            self.patients = all_patients[:num_seen_classes]
            self.need_classes = list(range(num_seen_classes * 2))  # Удваиваем количество классов

            for patient_dir in self.patients:
                patient_id = int(os.path.basename(patient_dir))
                for eye_dir in ['L', 'R']:
                    eye_path = os.path.join(patient_dir, eye_dir)
                    if os.path.exists(eye_path):
                        images = glob.glob(os.path.join(eye_path, '*.*'))
                        # Создаем уникальный класс для каждого глаза: L = 2*id, R = 2*id + 1
                        eye_class = 2 * patient_id if eye_dir == 'L' else 2 * patient_id + 1
                        self.data.extend([(img, eye_class) for img in images])

        elif mode in ["test_few", "test_all"]:
            self.patients = all_patients[num_seen_classes:]
            # Для теста берем оставшиеся классы, умноженные на 2
            total_patients = len(all_patients)
            self.need_classes = [2*i for i in range(num_seen_classes, total_patients)]

            for patient_dir in self.patients:
                patient_id = int(os.path.basename(patient_dir))
                for eye_dir in ['L', 'R']:
                    eye_path = os.path.join(patient_dir, eye_dir)
                    if os.path.exists(eye_path):
                        images = glob.glob(os.path.join(eye_path, '*.*'))
                        # Аналогично train: разделяем классы
                        eye_class = 2 * patient_id if eye_dir == 'L' else 2 * patient_id + 1
                        
                        if mode == "test_few":
                            selected_images = images[:1]
                        else:
                            selected_images = images[1:]
                            
                        self.data.extend([(img, eye_class) for img in selected_images])

        else:
            raise ValueError("Invalid mode. Use 'train' or 'test'")

        # Обновляем индексы классов
        for idx, (_, label) in enumerate(self.data):
            self.class_to_idxs[label].append(idx)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img_path, label = self.data[index]
        image = Image.open(img_path).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
        #if label > 249:
        #    print(label)
        return image, label


# class IrisDataset(torch.utils.data.Dataset):
#     def __init__(
#         self,
#         root=None,
#         num_seen_classes=20,
#         transform=None,
#         mode=None,
#         list_files=None
#     ):
#         self.root = root
#         self.transform = transform
#         self.mode = mode

#         self.need_classes = []
#         self.class_to_idxs = defaultdict(list)

#         all_patients = sorted(
#             glob.glob(os.path.join(self.root, '*')),
#             key=lambda x: int(os.path.basename(x))
#         )

        # self.data = []
        # if mode == "train":
        #     self.patients = all_patients[:num_seen_classes]
        #     self.need_classes = list(range(num_seen_classes))

        #     for patient_dir in self.patients:
        #         patient_id = int(os.path.basename(patient_dir))
        #         for eye_dir in ['L', 'R']:
        #             eye_path = os.path.join(patient_dir, eye_dir)
        #             if os.path.exists(eye_path):
        #                 images = glob.glob(os.path.join(eye_path, '*.*'))
        #                 self.data.extend([(img, patient_id) for img in images])

    #     elif mode in ["test_few", "test_all"]:
    #         self.patients = all_patients[num_seen_classes:]
    #         self.need_classes = [i for i in range(len(all_patients) - num_seen_classes)]

    #         for patient_dir in self.patients:
    #             patient_id = int(os.path.basename(patient_dir))
    #             patient_id -= 1

    #             if patient_id:
    #                 patient_id += 2
    #             for eye_dir in ['L', 'R']:
    #                 if eye_dir == "R":
    #                     patient_id += 1

    #                 eye_path = os.path.join(patient_dir, eye_dir)
    #                 if os.path.exists(eye_path):
    #                     images = glob.glob(os.path.join(eye_path, '*.*'))
    #                     if mode == "test_few":
    #                         selected_images = images[:1]
    #                     else:
    #                         selected_images = images[1:]
    #                     self.data.extend([(img, patient_id) for img in selected_images])
    #     else:
    #         raise ValueError("Invalid mode. Use 'train' or 'test'")

    #     for idx, (_, label) in enumerate(self.data):
    #         self.class_to_idxs[label].append(idx)

    # def __len__(self):
    #     return len(self.data)

    # def __getitem__(self, index):
    #     img_path, label = self.data[index]
    #     image = Image.open(img_path).convert('RGB')

    #     if self.transform:
    #         image = self.transform(image)
    #     if label > 249:
    #         print(f"label {label}")
    #     return image, label


def random_choice_except(options, exception):
    choice = exception
    while choice == exception:
        choice = random.choice(options)
    return choice


class Triplet(torch.utils.data.Dataset):
    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

    def __getitem__(self, index):
        sample1, target1 = self.dataset[index]

        positive_index = random_choice_except(
            self.dataset.class_to_idxs[target1],
            exception=index,
        )
        sample2, target2 = self.dataset[positive_index]

        negative_target = random_choice_except(
            self.dataset.need_classes,
            exception=target1,
        )

        negative_index = random.choice(
            self.dataset.class_to_idxs[negative_target],
        )

        sample3, target3 = self.dataset[negative_index]

        return [sample1, sample2, sample3], [target1, target2, target3]

    def __len__(self):
        return len(self.dataset)


def get_dataloaders_to_IRIS(
        path=None,
        num_seen=1,
        batch_size=1,
        transform_train=None,
        transform_test=None
    ):

    if transform_test == None:
        transform_test = transform_train

    train_data = IrisDataset(
        path,
        num_seen,
        transform_train,
        "train"
    )

    test_data = IrisDataset(
        path,
        num_seen,
        transform_test,
        "test_all"
    )

    test_dl = DataLoader(
        test_data,
        batch_size,
        num_workers=4,
        pin_memory=True,
    )

    train_dl = DataLoader(
        train_data,
        batch_size,
        num_workers=4,
        pin_memory=True,
        shuffle=True,
    )

    return train_dl, test_dl


def get_dl_2_IRIS(
    path=None,
    num_seen=1,
    batch_size=1,
    transform_train=None,
    transform_test=None
):
    # if transform_test == None:
    #     transform_test = transform_train

    train_data = IrisDataset(
        path,
        num_seen,
        transform_train,
        "train"
    )

    test_data_few = IrisDataset(
        path,
        num_seen,
        transform_train,
        "test_few"
    )

    test_data_all = IrisDataset(
        path,
        num_seen,
        transform_train,
        "test_all"
    )

    test_dl_few = DataLoader(
        test_data_few,
        batch_size,
        num_workers=4,
        pin_memory=True,
    )

    test_dl_all = DataLoader(
        test_data_all,
        batch_size,
        num_workers=4,
        pin_memory=True,
    )

    train_dl = DataLoader(
        train_data,
        batch_size,
        num_workers=4,
        pin_memory=True,
        shuffle=True
    )

    return train_dl, test_dl_few, test_dl_all


def run_emb_net(emb_net, dataloader, device=None, normalize=False):
    data_x = []
    data_y = []
    device = device
    with torch.no_grad():
        for inputs, labels in dataloader:
            feats = emb_net(inputs.to(device))
            if normalize:
                feats = F.normalize(feats)
            feats = feats.detach().cpu().numpy()
            labels = labels.detach().cpu().numpy()
            data_x.append(feats)
            data_y.append(labels)

    data_x = np.concatenate(data_x, axis=0)
    data_y = np.concatenate(data_y, axis=0)
    return data_x, data_y


def train_knn(emb_net, oneshot_dl, device, normalize=False):
    data_x, data_y = run_emb_net(emb_net, oneshot_dl, device, normalize)
    knn = KNeighborsClassifier(n_neighbors=1)
    knn = knn.fit(data_x, data_y)
    return knn


def testing_model(emb_net, test_dl_few, test_dl_all, device=None, normalize=False):
    data_x, data_y = run_emb_net(emb_net, test_dl_all, device, normalize)
    knn = train_knn(emb_net, test_dl_few, device)

    total_acc = 0
    total_cnt = 0
    for feat, label in zip(data_x, data_y):
        pred = knn.predict(feat[None]).squeeze(0)
        total_acc += pred == label
        total_cnt += 1

    acc = total_acc / total_cnt
    print(f"Accuracy = {acc:.2%} ({total_acc} / {total_cnt})")

    return acc, total_acc, total_cnt


def get_embeddings(model, dataloader, device=None):
    embeddings = []
    labels = []

    with torch.no_grad():
        for images, targets in dataloader:
            images = images.to(device)
            outputs = model(images)
            outputs = F.normalize(outputs)
            embeddings.append(outputs.cpu().numpy())
            labels.append(targets.cpu().numpy())

    return np.concatenate(embeddings), np.concatenate(labels)


In [40]:
from collections import defaultdict
import numpy as np
from scipy.spatial.distance import cdist
from sklearn.metrics import roc_curve
import torch
from itertools import combinations
import matplotlib.pyplot as plt
#from model.dataset import get_embeddings
import seaborn as sns
import matplotlib.patheffects as PathEffects
import numpy as np
import os
from datetime import datetime


def calculate_eer(embeddings, labels):
    unique_labels = np.unique(labels)
    genuine_dists = []
    impostor_dists = []

    for label in unique_labels:
        class_indices = np.where(labels == label)[0]
        class_emb = embeddings[class_indices]
        # 1. Внутриклассовые сравнения
        if len(class_indices) >= 2:
            #class_emb = embeddings[class_indices]
            dists = cdist(class_emb, class_emb, 'euclidean')
            genuine = dists[np.triu_indices(len(class_emb), 1)]  # Исправлено
            genuine_dists.extend(genuine)

        # 2. Межклассовые сравнения
        other_indices = np.where(labels != label)[0]
        if len(other_indices) > 0 and len(class_indices) > 0:
            other_emb = embeddings[other_indices]
            imp_dists = cdist(class_emb, other_emb, 'euclidean').flatten()
            impostor_dists.extend(imp_dists)

        # 5. Проверка на пустые списки
    if len(genuine_dists) == 0 or len(impostor_dists) == 0:
        return {
            'eer': 1.0,
            'threshold': None,
            'genuine_mean': 0,
            'impostor_mean': 0,
            'genuine_std': 0,
            'impostor_std': 0
        }

    # 3. Формирование меток
    y_true = np.concatenate([np.ones(len(genuine_dists)),
                            np.zeros(len(impostor_dists))])
    y_score = -np.concatenate([genuine_dists, impostor_dists])

    # 4. Расчет EER
    fpr, tpr, thresholds = roc_curve(y_true, y_score)
    fnr = 1 - tpr
    eer_idxx = np.nanargmin(np.abs(fnr - fpr))


    return {
        'eer': fpr[eer_idxx],
        'threshold': thresholds[eer_idxx],
        'genuine_mean': np.mean(genuine_dists),
        'impostor_mean': np.mean(impostor_dists),
        'genuine_std': np.std(genuine_dists),
        'impostor_std': np.std(impostor_dists)
    }


# Новая функция вычисления EER с учетом реального диапазона расстояний
def cal_eer(target, imposter):
    min_score = min(target.min(), imposter.min())
    max_score = max(target.max(), imposter.max())

    thresholds = torch.linspace(min_score, max_score, 1000)

    fars = torch.tensor([(imposter <= t).float().mean() for t in thresholds])
    frrs = torch.tensor([(target > t).float().mean() for t in thresholds])

    abs_diffs = torch.abs(fars - frrs)
    min_index = torch.argmin(abs_diffs)

    eer = (fars[min_index] + frrs[min_index]) / 2
    eer_threshold = thresholds[min_index]

    return eer, eer_threshold


ROOT_HIST = None
def get_hist(model, test_dl, device=None, out="result.png", subtitle=None, root=None):
    global ROOT_HIST
    if root == None:
        if ROOT_HIST == None:
            current_time = datetime.now()
            formated_time = current_time.strftime("%m-%d_%H:%M:%S")
            ROOT_HIST = f"hist-{subtitle}-{formated_time}"
            root = ROOT_HIST
            if not os.path.exists(ROOT_HIST):
                os.makedirs(ROOT_HIST)
        else:
            root = ROOT_HIST
    else:
        ROOT_HIST = root

    embeddings, labels = get_embeddings(model, test_dl, device)

    # Создаем словарь для группировки
    class_embeddings = {}

    # Проходим по всем эмбеддингам и меткам
    for emb, label in zip(embeddings, labels):
        label = label.item()  # Если метки в тензоре

        if label not in class_embeddings:
            # Создаем новый ключ с добавлением размерности батча
            class_embeddings[label] = torch.tensor(emb).unsqueeze(0)
        else:
            # Конкатенируем с существующими эмбеддингами класса
            class_embeddings[label] = torch.cat([
                class_embeddings[label],
                torch.tensor(emb).unsqueeze(0)
            ], dim=0)

    embeddings = class_embeddings

    all_target_scores = []
    all_imposter_scores = []

    for class_id in embeddings:
        # Target-пары внутри класса
        class_embs = embeddings[class_id]
        if class_embs.shape[0] > 1:
            # Генерация всех уникальных пар
            indices = torch.tensor(list(combinations(range(class_embs.shape[0]), 2)))
            target_pairs_a = class_embs[indices[:, 0]]
            target_pairs_b = class_embs[indices[:, 1]]
            target_scores = torch.norm(target_pairs_a - target_pairs_b, p=2, dim=1)
            all_target_scores.append(target_scores)

        # Imposter-пары с другими классами
        for other_class_id in embeddings:
            if other_class_id != class_id:
                other_embs = embeddings[other_class_id]

                # Генерация всех возможных комбинаций
                class_indices = torch.arange(class_embs.shape[0])
                other_indices = torch.arange(other_embs.shape[0])

                # Декартово произведение индексов
                pairs = torch.cartesian_prod(class_indices, other_indices)

                # Выборка соответствующих эмбеддингов
                imposter_pairs_a = class_embs[pairs[:, 0]]
                imposter_pairs_b = other_embs[pairs[:, 1]]

                imposter_scores = torch.norm(imposter_pairs_a - imposter_pairs_b, p=2, dim=1)
                all_imposter_scores.append(imposter_scores)

    # Объединение результатов
    target_scores = torch.cat(all_target_scores) if all_target_scores else torch.tensor([])
    imposter_scores = torch.cat(all_imposter_scores)


    # Гистограммы

    # Пересчет и построение графика с новой EER-точкой
    eer, eer_threshold = cal_eer(target_scores, imposter_scores)

    plt.figure(figsize=(10, 6))

    total_samples_target = target_scores.size()
    total_samples_imposter = imposter_scores.size()

    # Построение гистограммы с вероятностями
    plt.hist(target_scores.numpy(), bins=100, alpha=0.5, label='Target',
            weights=np.ones_like(target_scores.numpy()) / total_samples_target)  # Нормализация


    plt.hist(imposter_scores.numpy(), bins=100, alpha=0.5, label='Imposter',
             weights=np.ones_like(imposter_scores.numpy()) / total_samples_imposter)

    plt.axvline(x=eer_threshold.item(), color='red', linestyle='--', linewidth=2, label=f'EER Threshold ({eer_threshold:.2f})')
    plt.xlabel('Score (Euclidean)')
    plt.ylabel('Probability')
    plt.legend()

    if subtitle != None:
        plt.suptitle(subtitle)

    plt.savefig(os.path.join(root, out))

    plt.close()

    print(f"EER: {eer.item():.4f} at threshold: {eer_threshold.item():.4f}")


PATH_PLOT = None
# Define our own plot function
def scatter(x, labels, subtitle=None, root=None):
    global PATH_PLOT
    if root == None:
        if PATH_PLOT == None:
            current_time = datetime.now()
            formated_time = current_time.strftime("%m-%d_%H:%M:%S")
            PATH_PLOT = f"plot_{subtitle[:-1]}_{formated_time}"
            root = PATH_PLOT
            if not os.path.exists(root):
                os.makedirs(root)
        else:
            root = PATH_PLOT
    else:
        PATH_PLOT = root

    unique_labels = np.unique(labels)
    labels = np.searchsorted(unique_labels, labels)  # Переиндексация в 0,1,2,...
    num_classes = len(unique_labels)
    palette = np.array(sns.color_palette("hls", num_classes)) # Choosing color
    # Create a seaborn scatter plot #
    f = plt.figure(figsize=(8, 8))
    ax = plt.subplot(aspect='equal')
    sc = ax.scatter(x[:,0], x[:,1], lw=0, s=40,
                    c=palette[labels.astype(np.int32)])

    plt.xlim(-25, 25)
    plt.ylim(-25, 25)

    ax.axis('off')
    ax.axis('tight')

    # Add label on top of each cluster ##
    idx2name = [str(x+1) for x in range(num_classes)]
    txts = []
    for i in range(num_classes):
        # Position of each label.
        xtext, ytext = np.median(x[labels == i, :], axis=0)
        txt = ax.text(xtext, ytext, idx2name[i], fontsize=24)
        txt.set_path_effects([
            PathEffects.Stroke(linewidth=5, foreground="w"),
            PathEffects.Normal()])
        txts.append(txt)


    if subtitle != None:
        plt.suptitle(subtitle)

    if not os.path.exists(root):
        os.makedirs(root)

    plt.savefig(os.path.join(root, str(subtitle)))
    plt.close()


def zero_shot_inference(sample_embedding, class_embeddings, class_names):
    """
    Оптимизированная версия с использованием встроенных функций PyTorch
    """
    max_similarity = -float('inf')
    predicted_class = None

    for class_name in class_names:
        # Получаем все эмбеддинги класса (тензор [N, D])
        class_embs = class_embeddings[class_name]

        # Вычисляем косинусную схожесть сразу для всех примеров класса
        # Добавляем размерность для батча (из [D] -> [1, D])
        similarities = F.cosine_similarity(
            sample_embedding.unsqueeze(0),  # [1, D]
            class_embs,                    # [N, D]
            dim=1
        )

        # Находим максимальную схожесть для класса
        class_max_sim = torch.max(similarities).item()

        if class_max_sim > max_similarity:
            max_similarity = class_max_sim
            predicted_class = class_name

    return predicted_class


def get_emb(model, dataloader, device):
    """Получение эмбеддингов в виде тензоров PyTorch"""
    model.eval()
    embeddings = []
    labels = []
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            emb = model(inputs)
            emb = F.normalize(emb, p=2, dim=1)
            embeddings.append(emb)  # Сохраняем как тензор
            labels.append(targets.to(device))
    return torch.cat(embeddings), torch.cat(labels)


def check(model, model2, test_dl, device):
    embeddings, labels = get_emb(model, test_dl, device)
    x, lab = get_emb(model2, test_dl, device)

    # Исправляем преобразование numpy в тензоры
    class_embeddings = {}
    for emb, label in zip(embeddings, labels):
        label = label.item()
        emb = emb.cpu()  # Переносим на CPU для совместимости

        if label not in class_embeddings:
            class_embeddings[label] = emb.unsqueeze(0)
        else:
            class_embeddings[label] = torch.cat([
                class_embeddings[label],
                emb.unsqueeze(0)
            ], dim=0)

    # Преобразуем все к одному устройству
    device = next(model.parameters()).device
    correct = 0
    for emb, true_label in zip(x, lab):
        emb = emb.to(device)
        true_label = true_label.item()

        # Конвертируем эмбеддинги класса к нужному устройству
        class_embs_on_device = {
            k: v.to(device) for k, v in class_embeddings.items()
        }

        predicted_class = zero_shot_inference(
            emb,
            class_embs_on_device,
            list(class_embeddings.keys())
        )
        correct += (predicted_class == true_label)

    return correct / len(lab)


In [6]:
import torchvision.models as models
from torch import nn


def get_resnet(num_classes=512):
    model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, num_classes)

    return model

def get_resnet152(num_classes=512):
    model = models.resnet152(weights=models.ResNet152_Weights.DEFAULT)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, num_classes)

    return model


In [7]:
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F


class Hard_mining_TripletLoss(nn.Module):
    def __init__(self, margin=0.5, device=None):
        super().__init__()
        self.margin = margin
        self.device = device

    def _get_anchor_positive_triplet_mask(self, labels):
        """Return a 2D mask where mask[a, p] is True if a and p are distinct and have same label.

        Args:
            labels: tf.int32 `Tensor` with shape [batch_size]

        Returns:
            mask: tf.bool `Tensor` with shape [batch_size, batch_size]
        """

        # Check that i and j are distinct
        indices_equal = torch.eye(labels.size()[0]).bool().to(self.device)
        indices_not_equal = ~indices_equal # flip booleans

        # Check if labels[i] == labels[j]
        # Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1)
        labels_equal = torch.unsqueeze(labels, 0) == torch.unsqueeze(labels, 1)
        # Combine the two masks
        mask = indices_not_equal & labels_equal

        return mask


    def _get_anchor_negative_triplet_mask(self, labels):
        """Return a 2D mask where mask[a, n] is True if a and n have distinct labels.

        Args:
            labels: tf.int32 `Tensor` with shape [batch_size]

        Returns:
            mask: tf.bool `Tensor` with shape [batch_size, batch_size]
        """
        # Check if labels[i] != labels[k]
        # Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1)
        labels_equal = torch.unsqueeze(labels, 0) == torch.unsqueeze(labels, 1)
        mask = ~labels_equal # invert the boolean tensor

        return mask

    def forward(self, embeddings, labels):
        # Get the pairwise distance matrix
        pairwise_dist = torch.cdist(embeddings, embeddings, p=2)

        # For each anchor, get the hardest positive
        # First, we need to get a mask for every valid positive (they should have same label)
        mask_anchor_positive = self._get_anchor_positive_triplet_mask(labels)

        # We put to 0 any element where (a, p) is not valid (valid if a != p and label(a) == label(p))
        anchor_positive_dist = mask_anchor_positive * pairwise_dist

        # shape (batch_size, 1)
        hardest_positive_dist = torch.max(anchor_positive_dist, 1, keepdim=True)[0]

        # For each anchor, get the hardest negative
        # First, we need to get a mask for every valid negative (they should have different labels)
        mask_anchor_negative = self._get_anchor_negative_triplet_mask(labels)

        # We add the maximum value in each row to the invalid negatives (label(a) == label(n))
        max_anchor_negative_dist = torch.max(pairwise_dist, 1, keepdim=True)[0]
        anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * ~(mask_anchor_negative)

        # shape (batch_size,)
        hardest_negative_dist = torch.min(anchor_negative_dist, 1, keepdim=True)[0]

        # Combine biggest d(a, p) and smallest d(a, n) into final triplet loss
        triplet_loss = torch.max(hardest_positive_dist - hardest_negative_dist + self.margin, torch.Tensor([0.0]).to(self.device))

        # Get final mean triplet loss
        triplet_loss = torch.mean(triplet_loss)

        return triplet_loss


class BatchHardTripletLoss(nn.Module):
    def __init__(self, margin=0.5, device=None):
        super().__init__()
        self.margin = margin
        self.device = device

    def get_anchor_positive_mask(self, labels):
        indices_eq = torch.eq(labels.unsqueeze(0), labels.unsqueeze(1))
        identity = torch.eye(labels.size(0), dtype=torch.bool, device=self.device)
        return indices_eq & ~identity

    def get_anchor_negative_mask(self, labels):
        return torch.ne(labels.unsqueeze(0), labels.unsqueeze(1))

    def forward(self, embeddings, labels):
        pairwise_dist = torch.cdist(embeddings, embeddings, p=2)

        mask_anchor_positive = self.get_anchor_positive_mask(labels)
        mask_anchor_negative = self.get_anchor_negative_mask(labels)

        # Hardest positive
        anchor_positive_dist = pairwise_dist * mask_anchor_positive.float()
        hardest_positive_dist, _ = anchor_positive_dist.max(dim=1, keepdim=True)

        # Hardest negative
        anchor_negative_dist = pairwise_dist * mask_anchor_negative.float()
        hardest_negative_dist, _ = anchor_negative_dist.min(dim=1, keepdim=True)

        # Расчет Triplet Loss
        triplet_loss = F.relu(
            hardest_positive_dist - hardest_negative_dist + self.margin
        )

        # Усреднение с учетом валидных триплетов
        valid_mask = triplet_loss > 1e-16
        loss = triplet_loss[valid_mask].sum() / (valid_mask.sum().float() + 1e-16)

        return loss if not torch.isnan(loss) else torch.tensor(0.0, device=self.device)


class TripletLoss(nn.Module):
    def __init__(self, margin, device=None):
        super().__init__()
        self.margin = margin
        self.device = device

    def forward(self, fs, ys):
        anchor, positive, negative = fs
        anchor = anchor.to(self.device)
        positive = positive.to(self.device)
        negative = negative.to(self.device)

        dist_pos = (anchor - positive).square().sum(axis=-1)
        dist_neg = (anchor - negative).square().sum(axis=-1)

        loss = F.relu(dist_pos - dist_neg + self.margin)

        return loss.mean()



In [8]:
!gdown --id 1Kc4iP695Ggh_OkDGtIYNzu26jhv58nTA

Downloading...
From (original): https://drive.google.com/uc?id=1Kc4iP695Ggh_OkDGtIYNzu26jhv58nTA
From (redirected): https://drive.google.com/uc?id=1Kc4iP695Ggh_OkDGtIYNzu26jhv58nTA&confirm=t&uuid=254cc0e8-4ebb-4095-914e-79e142174823
To: /kaggle/working/norm_photo.zip
100%|██████████████████████████████████████| 79.9M/79.9M [00:02<00:00, 37.6MB/s]


In [11]:
!unzip -q norm_photo.zip -d /content/data

In [12]:
!ls /content/data

Norm_photo


In [43]:
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision.datasets import ImageFolder
import torchvision
import torch.nn.functional as F
from torch.nn.functional import normalize
from torchvision import datasets
from torch import nn
import torch.optim as optim
import torch
from torchvision.transforms import transforms
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import random
import os
from sklearn.manifold import TSNE


## -------------------------------------------------------------- ##


torch.manual_seed(42)

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

NETWORK = False    # Do you want to use new Network? [True] or backbone (resnet) [False]
SUBTITLE = None

#hyper parameters:
num_epochs = 50
PATH_TO_NORMALIZE_PHOTOES = "/content/data/Norm_photo"
NUM_SEEN_CLASSES = 180
batch_size = 64
margin = 0.6
learning_rate = 0.0001


original_model_parameters = {
    'in_channels' : 2048,
    'num_parts' : 4,
    'gamma' : 0.05,
    'num_classes' : 512
}

PATH_PLOT = None
ROOT_HIST = None
if __name__ == '__main__':

    if NETWORK:
        model = Triplet_Network(original_model_parameters).to(device)
        SUBTITLE = "Phase Model"
    else:
        model = get_resnet(original_model_parameters['num_classes']).to(device)
        SUBTITLE = "ResNet50"


    train_dataloader, few_dataloader, test_dataloader = get_dl_2_IRIS(
        PATH_TO_NORMALIZE_PHOTOES,
        NUM_SEEN_CLASSES,
        batch_size,
        norm_transform
    )

    criterion = Hard_mining_TripletLoss(margin=margin, device=device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)

    val_losses = []

    best_auc = 0.0
    train_losses = []
    val_metrics = {'FAR': [], 'FRR': [], 'ERR': [], 'AUC': []}
    eer_list = []
    val_accuracies = []
    
    print(f"Starts fiting the model, Margin = {margin}")
    for e in range(num_epochs):
        model = model.train()

        with tqdm(train_dataloader,
                desc=f"Epoch {e+1}/{num_epochs} [Train]",
                leave=False,
                dynamic_ncols=True
        ) as pbar:
            train_loss = 0.0
            for batch in pbar:
                images, labels = batch
                images = images.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                embeddings = model(images)

                loss = criterion(embeddings, labels)

                loss.backward()

                optimizer.step()

                train_loss += loss.item() * images.shape[0]

                pbar.set_postfix({'loss': f"{loss.item():.4f}"})

        train_loss /= len(train_dataloader.dataset)
        train_losses.append(train_loss)

        print(f"Epoch {e+1}, Loss: {train_loss:.4f} \n LR: {optimizer.param_groups[0]['lr']}")

        model.eval()
        #print(f"Result : {check(resnet50, model, test_dataloader, device) * 100:.4f} %")
        acc, _, __ = testing_model(model, few_dataloader, test_dataloader, device=device)
        val_accuracies.append(acc)
        
        embeddings, labels = get_embeddings(model, test_dataloader, device=device)
        metrics = calculate_eer(embeddings, labels)

        eer_list.append(metrics['eer'])
        
        print(f"""
                EER: {metrics['eer']:.7f}
                Threshold: {metrics['threshold']:.4f}
                Genuine distances: {metrics['genuine_mean']:.2f} ± {metrics['genuine_std']:.2f}
                Impostor distances: {metrics['impostor_mean']:.2f} ± {metrics['impostor_std']:.2f}
        """)

        if e % 5 == 0:
            tsne = TSNE(random_state=0, perplexity=20)
            embeddings, labels = get_embeddings(model, test_dataloader, device=device)
            train_tsne_embeds = tsne.fit_transform(embeddings)

            get_hist(model, test_dataloader, device, f"hist_{e}.png", SUBTITLE)

            scatter(train_tsne_embeds, labels.astype(np.int32), SUBTITLE + str(e))


        scheduler.step(train_loss)



Starts fiting the model, Margin = 0.6


                                                                                

Epoch 1, Loss: 0.1249 
 LR: 0.0001




Accuracy = 28.87% (125 / 433)

                EER: 0.2427795
                Threshold: -0.7323
                Genuine distances: 0.67 ± 0.10
                Impostor distances: 0.79 ± 0.09
        
EER: 0.2438 at threshold: 0.7329


                                                                                

Epoch 2, Loss: 0.1014 
 LR: 0.0001




Accuracy = 30.95% (134 / 433)

                EER: 0.2139844
                Threshold: -0.7266
                Genuine distances: 0.66 ± 0.10
                Impostor distances: 0.81 ± 0.11
        


                                                                                

Epoch 3, Loss: 0.0985 
 LR: 0.0001




Accuracy = 42.26% (183 / 433)

                EER: 0.1770224
                Threshold: -0.8939
                Genuine distances: 0.75 ± 0.16
                Impostor distances: 1.04 ± 0.16
        


                                                                                

Epoch 4, Loss: 0.1002 
 LR: 0.0001




Accuracy = 48.96% (212 / 433)

                EER: 0.1502663
                Threshold: -0.9860
                Genuine distances: 0.81 ± 0.17
                Impostor distances: 1.16 ± 0.17
        


                                                                                

Epoch 5, Loss: 0.0886 
 LR: 0.0001




Accuracy = 52.19% (226 / 433)

                EER: 0.1451905
                Threshold: -1.0252
                Genuine distances: 0.78 ± 0.21
                Impostor distances: 1.25 ± 0.21
        


                                                                                

Epoch 6, Loss: 0.0901 
 LR: 0.0001




Accuracy = 55.66% (241 / 433)

                EER: 0.1196707
                Threshold: -1.0581
                Genuine distances: 0.78 ± 0.22
                Impostor distances: 1.31 ± 0.20
        
EER: 0.1195 at threshold: 1.0574


                                                                                

Epoch 7, Loss: 0.0796 
 LR: 0.0001




Accuracy = 62.12% (269 / 433)

                EER: 0.0960381
                Threshold: -1.0435
                Genuine distances: 0.75 ± 0.22
                Impostor distances: 1.33 ± 0.21
        


                                                                                

Epoch 8, Loss: 0.0644 
 LR: 0.0001




Accuracy = 65.36% (283 / 433)

                EER: 0.1127837
                Threshold: -1.0083
                Genuine distances: 0.70 ± 0.24
                Impostor distances: 1.29 ± 0.22
        


                                                                                

Epoch 9, Loss: 0.0653 
 LR: 0.0001




Accuracy = 67.44% (292 / 433)

                EER: 0.0989881
                Threshold: -1.0581
                Genuine distances: 0.72 ± 0.24
                Impostor distances: 1.33 ± 0.21
        


                                                                                 

Epoch 10, Loss: 0.0528 
 LR: 0.0001




Accuracy = 65.82% (285 / 433)

                EER: 0.0976432
                Threshold: -1.0386
                Genuine distances: 0.70 ± 0.25
                Impostor distances: 1.33 ± 0.21
        


                                                                                 

Epoch 11, Loss: 0.0431 
 LR: 0.0001




Accuracy = 67.67% (293 / 433)

                EER: 0.0866566
                Threshold: -1.0233
                Genuine distances: 0.67 ± 0.23
                Impostor distances: 1.32 ± 0.21
        
EER: 0.0865 at threshold: 1.0224


                                                                                 

Epoch 12, Loss: 0.0458 
 LR: 0.0001




Accuracy = 68.82% (298 / 433)

                EER: 0.0956910
                Threshold: -1.0004
                Genuine distances: 0.67 ± 0.24
                Impostor distances: 1.29 ± 0.21
        


                                                                                 

Epoch 13, Loss: 0.0358 
 LR: 0.0001




Accuracy = 70.21% (304 / 433)

                EER: 0.0692277
                Threshold: -1.0141
                Genuine distances: 0.65 ± 0.23
                Impostor distances: 1.35 ± 0.21
        


                                                                                 

Epoch 14, Loss: 0.0328 
 LR: 0.0001




Accuracy = 69.75% (302 / 433)

                EER: 0.0775463
                Threshold: -1.0093
                Genuine distances: 0.65 ± 0.22
                Impostor distances: 1.33 ± 0.21
        


                                                                                 

Epoch 15, Loss: 0.0281 
 LR: 0.0001




Accuracy = 72.75% (315 / 433)

                EER: 0.0695314
                Threshold: -1.0069
                Genuine distances: 0.63 ± 0.22
                Impostor distances: 1.35 ± 0.22
        


                                                                                 

Epoch 16, Loss: 0.0260 
 LR: 0.0001




Accuracy = 74.13% (321 / 433)

                EER: 0.0737503
                Threshold: -1.0452
                Genuine distances: 0.66 ± 0.23
                Impostor distances: 1.36 ± 0.20
        
EER: 0.0732 at threshold: 1.0455


                                                                                 

Epoch 17, Loss: 0.0272 
 LR: 0.0001




Accuracy = 74.83% (324 / 433)

                EER: 0.0673405
                Threshold: -1.0497
                Genuine distances: 0.66 ± 0.24
                Impostor distances: 1.38 ± 0.20
        


                                                                                 

Epoch 18, Loss: 0.0244 
 LR: 0.0001




Accuracy = 78.52% (340 / 433)

                EER: 0.0646725
                Threshold: -1.0230
                Genuine distances: 0.63 ± 0.23
                Impostor distances: 1.37 ± 0.20
        


                                                                                 

Epoch 19, Loss: 0.0223 
 LR: 0.0001




Accuracy = 75.06% (325 / 433)

                EER: 0.0667874
                Threshold: -1.0237
                Genuine distances: 0.63 ± 0.23
                Impostor distances: 1.35 ± 0.20
        


                                                                                 

Epoch 20, Loss: 0.0225 
 LR: 0.0001




Accuracy = 77.60% (336 / 433)

                EER: 0.0685878
                Threshold: -1.0453
                Genuine distances: 0.64 ± 0.24
                Impostor distances: 1.35 ± 0.19
        


                                                                                 

Epoch 21, Loss: 0.0237 
 LR: 0.0001




Accuracy = 76.44% (331 / 433)

                EER: 0.0743468
                Threshold: -1.0300
                Genuine distances: 0.63 ± 0.24
                Impostor distances: 1.34 ± 0.20
        
EER: 0.0746 at threshold: 1.0306


                                                                                 

Epoch 22, Loss: 0.0174 
 LR: 0.0001




Accuracy = 78.75% (341 / 433)

                EER: 0.0665922
                Threshold: -1.0317
                Genuine distances: 0.64 ± 0.23
                Impostor distances: 1.35 ± 0.20
        


                                                                                 

Epoch 23, Loss: 0.0154 
 LR: 0.0001




Accuracy = 77.83% (337 / 433)

                EER: 0.0623516
                Threshold: -1.0017
                Genuine distances: 0.62 ± 0.22
                Impostor distances: 1.34 ± 0.20
        


                                                                                 

Epoch 24, Loss: 0.0168 
 LR: 0.0001




Accuracy = 78.29% (339 / 433)

                EER: 0.0735334
                Threshold: -1.0148
                Genuine distances: 0.62 ± 0.24
                Impostor distances: 1.34 ± 0.21
        


                                                                                 

Epoch 25, Loss: 0.0135 
 LR: 0.0001




Accuracy = 79.45% (344 / 433)

                EER: 0.0631975
                Threshold: -1.0458
                Genuine distances: 0.62 ± 0.24
                Impostor distances: 1.38 ± 0.20
        


                                                                                 

Epoch 26, Loss: 0.0130 
 LR: 0.0001




Accuracy = 80.60% (349 / 433)

                EER: 0.0555188
                Threshold: -1.0600
                Genuine distances: 0.64 ± 0.24
                Impostor distances: 1.39 ± 0.19
        
EER: 0.0556 at threshold: 1.0595


                                                                                 

Epoch 27, Loss: 0.0122 
 LR: 0.0001




Accuracy = 77.83% (337 / 433)

                EER: 0.0665813
                Threshold: -1.0428
                Genuine distances: 0.63 ± 0.24
                Impostor distances: 1.36 ± 0.20
        


                                                                                 

Epoch 28, Loss: 0.0085 
 LR: 0.0001




Accuracy = 80.60% (349 / 433)

                EER: 0.0518096
                Threshold: -1.0206
                Genuine distances: 0.62 ± 0.22
                Impostor distances: 1.37 ± 0.19
        


                                                                                 

Epoch 29, Loss: 0.0096 
 LR: 0.0001




Accuracy = 78.75% (341 / 433)

                EER: 0.0594991
                Threshold: -1.0414
                Genuine distances: 0.63 ± 0.23
                Impostor distances: 1.37 ± 0.19
        


                                                                                 

Epoch 30, Loss: 0.0076 
 LR: 0.0001




Accuracy = 81.99% (355 / 433)

                EER: 0.0578506
                Threshold: -1.0389
                Genuine distances: 0.63 ± 0.23
                Impostor distances: 1.36 ± 0.19
        


                                                                                 

Epoch 31, Loss: 0.0096 
 LR: 0.0001




Accuracy = 80.83% (350 / 433)

                EER: 0.0555188
                Threshold: -1.0236
                Genuine distances: 0.62 ± 0.23
                Impostor distances: 1.37 ± 0.20
        
EER: 0.0557 at threshold: 1.0235


                                                                                 

Epoch 32, Loss: 0.0083 
 LR: 0.0001




Accuracy = 79.21% (343 / 433)

                EER: 0.0624058
                Threshold: -1.0152
                Genuine distances: 0.61 ± 0.24
                Impostor distances: 1.34 ± 0.19
        


                                                                                 

Epoch 33, Loss: 0.0056 
 LR: 0.0001




Accuracy = 80.14% (347 / 433)

                EER: 0.0670477
                Threshold: -1.0492
                Genuine distances: 0.62 ± 0.26
                Impostor distances: 1.36 ± 0.20
        


                                                                                 

Epoch 34, Loss: 0.0076 
 LR: 0.0001




Accuracy = 82.45% (357 / 433)

                EER: 0.0618093
                Threshold: -1.0323
                Genuine distances: 0.62 ± 0.24
                Impostor distances: 1.36 ± 0.19
        


                                                                                 

Epoch 35, Loss: 0.0047 
 LR: 0.0001




Accuracy = 82.68% (358 / 433)

                EER: 0.0496079
                Threshold: -1.0164
                Genuine distances: 0.61 ± 0.23
                Impostor distances: 1.36 ± 0.19
        


                                                                                 

Epoch 36, Loss: 0.0061 
 LR: 0.0001




Accuracy = 79.68% (345 / 433)

                EER: 0.0598462
                Threshold: -1.0314
                Genuine distances: 0.62 ± 0.23
                Impostor distances: 1.36 ± 0.19
        
EER: 0.0600 at threshold: 1.0314


                                                                                 

Epoch 37, Loss: 0.0062 
 LR: 0.0001




Accuracy = 82.45% (357 / 433)

                EER: 0.0666356
                Threshold: -1.0139
                Genuine distances: 0.61 ± 0.24
                Impostor distances: 1.33 ± 0.19
        


                                                                                 

Epoch 38, Loss: 0.0062 
 LR: 0.0001




Accuracy = 80.83% (350 / 433)

                EER: 0.0549223
                Threshold: -1.0422
                Genuine distances: 0.63 ± 0.24
                Impostor distances: 1.37 ± 0.18
        


                                                                                 

Epoch 39, Loss: 0.0057 
 LR: 1e-05




Accuracy = 82.91% (359 / 433)

                EER: 0.0554212
                Threshold: -1.0516
                Genuine distances: 0.63 ± 0.23
                Impostor distances: 1.37 ± 0.18
        


                                                                                 

Epoch 40, Loss: 0.0041 
 LR: 1e-05




Accuracy = 81.52% (353 / 433)

                EER: 0.0596727
                Threshold: -1.0611
                Genuine distances: 0.64 ± 0.24
                Impostor distances: 1.37 ± 0.18
        


                                                                                 

Epoch 41, Loss: 0.0025 
 LR: 1e-05




Accuracy = 82.45% (357 / 433)

                EER: 0.0568528
                Threshold: -1.0534
                Genuine distances: 0.63 ± 0.23
                Impostor distances: 1.37 ± 0.18
        
EER: 0.0570 at threshold: 1.0547


                                                                                 

Epoch 42, Loss: 0.0039 
 LR: 1e-05




Accuracy = 81.76% (354 / 433)

                EER: 0.0574385
                Threshold: -1.0515
                Genuine distances: 0.63 ± 0.24
                Impostor distances: 1.37 ± 0.18
        


                                                                                 

Epoch 43, Loss: 0.0044 
 LR: 1e-05




Accuracy = 83.37% (361 / 433)

                EER: 0.0531328
                Threshold: -1.0491
                Genuine distances: 0.63 ± 0.23
                Impostor distances: 1.38 ± 0.18
        


                                                                                 

Epoch 44, Loss: 0.0025 
 LR: 1e-05




Accuracy = 82.45% (357 / 433)

                EER: 0.0562129
                Threshold: -1.0456
                Genuine distances: 0.63 ± 0.23
                Impostor distances: 1.37 ± 0.18
        


                                                                                 

Epoch 45, Loss: 0.0031 
 LR: 1e-05




Accuracy = 82.22% (356 / 433)

                EER: 0.0574927
                Threshold: -1.0559
                Genuine distances: 0.63 ± 0.23
                Impostor distances: 1.37 ± 0.18
        


                                                                                 

Epoch 46, Loss: 0.0012 
 LR: 1e-05




Accuracy = 81.99% (355 / 433)

                EER: 0.0551067
                Threshold: -1.0548
                Genuine distances: 0.63 ± 0.23
                Impostor distances: 1.37 ± 0.18
        
EER: 0.0553 at threshold: 1.0550


                                                                                 

Epoch 47, Loss: 0.0025 
 LR: 1e-05




Accuracy = 82.45% (357 / 433)

                EER: 0.0545427
                Threshold: -1.0480
                Genuine distances: 0.63 ± 0.23
                Impostor distances: 1.37 ± 0.18
        


                                                                                 

Epoch 48, Loss: 0.0018 
 LR: 1e-05




Accuracy = 82.91% (359 / 433)

                EER: 0.0513866
                Threshold: -1.0447
                Genuine distances: 0.62 ± 0.23
                Impostor distances: 1.38 ± 0.18
        


                                                                                 

Epoch 49, Loss: 0.0032 
 LR: 1e-05




Accuracy = 80.83% (350 / 433)

                EER: 0.0569070
                Threshold: -1.0566
                Genuine distances: 0.63 ± 0.24
                Impostor distances: 1.38 ± 0.18
        


                                                                                 

Epoch 50, Loss: 0.0021 
 LR: 1.0000000000000002e-06




Accuracy = 81.76% (354 / 433)

                EER: 0.0548464
                Threshold: -1.0528
                Genuine distances: 0.62 ± 0.23
                Impostor distances: 1.38 ± 0.18
        


In [51]:
sns.set_theme(style="whitegrid", context="talk", palette="colorblind")
plt.rcParams['font.family'] = 'DejaVu Sans'

        # 1. Loss train
plt.figure(figsize=(13, 6))
sns.lineplot(
            x=range(len(train_losses)),
            y=train_losses,
            #label='Train Loss',
            linewidth=2.5,
            marker='o',
            markersize=8
)


plt.title('Training Loss', fontsize=16, pad=20)
plt.xlabel('Epoch', fontsize=14)
plt.ylabel('Loss', fontsize=14)
#plt.legend(title='', frameon=True, facecolor='white')
plt.tight_layout()
plt.suptitle(SUBTITLE)

plt.savefig(f"./loss_output.svg", dpi=300, format="svg", bbox_inches='tight')
plt.close()

# 2. Accuracy
plt.figure(figsize=(13, 6))
palette = sns.color_palette("husl", 2)

eer_list2 = eer_list
sns.lineplot(
            x=range(len(eer_list)),
            y=eer_list,
            color=palette[0],
            linewidth=2.5,
            marker='^',
            markersize=8
)

plt.title('ERR', fontsize=16, pad=20)
plt.xlabel('Epoch', fontsize=14)
plt.ylabel('ERR', fontsize=14)
#plt.legend(title='', frameon=True, facecolor='white')
plt.tight_layout()
plt.suptitle(SUBTITLE)
plt.savefig(f"./eer_output.svg", dpi=300, format="svg", bbox_inches='tight')
plt.close()


# 2. Accuracy
plt.figure(figsize=(13, 6))
palette = sns.color_palette("husl", 2)

val_accur2 = val_accuracies
sns.lineplot(
            x=range(len(val_accuracies)),
            y=[100 * i for i in val_accuracies],
            #label='TOP 1 ACC',
            color=palette[0],
            linewidth=2.5,
            marker='^',
            markersize=8
)

plt.title('Validation Accuracy', fontsize=16, pad=20)
plt.xlabel('Epoch', fontsize=14)
plt.ylabel('Accuracy (%)', fontsize=14)
#plt.legend(title='', frameon=True, facecolor='white')
plt.tight_layout()
plt.suptitle(SUBTITLE)
plt.savefig(f"./accuracy_output.svg", dpi=300, format="svg", bbox_inches='tight')
plt.close()
np.savez(f'./array_losses.npz',
                val_accuracies=val_accuracies)

#np.savez(f'{self.dir_plot}/array_accuracy.npz',
#                val_accuracies_top1=self.val_accuracies_top1,
#                val_accuracies_top5=self.val_accuracies_top5)

  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
