In [None]:
from sklearn.metrics import roc_curve, det_curve, DetCurveDisplay
import matplotlib.pyplot as plt
import torch

import numpy as np
np.set_printoptions(suppress=True)

from pathlib import Path
import concurrent.futures
from functools import partial

In [None]:
thresholds = torch.linspace(-3, -0.01, 300)


def compute_model(exp_dir: Path, n_episodes: int, n_labels: int, device: str = 'cuda:0',
                  thresholds: torch.Tensor = thresholds):
    """Average per-threshold false-positive / false-negative rates over episodes.

    Each episode file ``exp_dir / f'ep{ep}.pt'`` is loaded as a 2-D tensor whose
    column 0 holds each sample's integer label and whose remaining columns hold
    that sample's score for each of the ``n_labels`` labels (layout inferred
    from the indexing below). A sample counts as predicted-positive for a label
    when its score is strictly greater than the threshold.

    Args:
        exp_dir: directory containing ``ep0.pt`` ... ``ep{n_episodes-1}.pt``
            (``str`` or ``Path``).
        n_episodes: number of episode files to average over; must be >= 1.
        n_labels: number of label/score columns.
        device: torch device the computation runs on.
        thresholds: 1-D tensor of decision thresholds. Defaults to the grid
            defined at the top of this cell (bound when the cell is run).

    Returns:
        ``(FPR, FNR)``: numpy arrays of shape ``(len(thresholds),)``. For each
        threshold, the per-label rates are averaged over labels, then over
        episodes.

    Raises:
        ValueError: if ``n_episodes`` < 1.
    """
    if n_episodes < 1:
        raise ValueError(f'n_episodes must be >= 1, got {n_episodes}')

    exp_dir = Path(exp_dir)
    # Loop-invariant tensors: build them on the target device once.
    thr = thresholds.to(device).view(-1, 1, 1)                      # (T, 1, 1)
    label_ids = torch.arange(n_labels, device=device).view(-1, 1)   # (L, 1)

    def compute_ep(ep):
        # map_location loads straight onto `device` instead of first restoring the
        # tensor onto whatever device it was saved from. Trusted local files only:
        # torch.load unpickles its input.
        results = torch.load(exp_dir / f'ep{ep}.pt', map_location=device)
        labels = results[:, 0].to(int)
        scores = results[:, 1:]

        y_true = labels == label_ids     # (L, N): sample n belongs to label l
        y_pred = scores.T > thr          # (T, L, N): score above threshold t

        fpr = (~y_true & y_pred).sum(-1) / (~y_true).sum(1)   # (T, L)
        fnr = (y_true & ~y_pred).sum(-1) / y_true.sum(1)      # (T, L)

        # A label with no positives (or no negatives) in this episode yields
        # 0/0 = NaN; nanmean skips it instead of turning the whole row into NaN.
        return torch.nanmean(fpr, dim=1), torch.nanmean(fnr, dim=1)

    FPR = FNR = 0
    for ep in range(n_episodes):
        fpr, fnr = compute_ep(ep)
        FPR = FPR + fpr
        FNR = FNR + fnr

    FPR = (FPR / n_episodes).cpu().numpy()
    FNR = (FNR / n_episodes).cpu().numpy()
    return FPR, FNR


In [None]:
log_ticks = np.array([1, 2, 5, 10, 20, 50, 80, 100])
log_positions = np.linspace(0, 1, len(log_ticks))


def plot_det(fpr, fnr, label: str, ax=None):
    """Draw one DET curve on a warped percentage scale.

    Rates (fractions in [0, 1]) are converted to percentages and mapped through
    ``np.interp`` so that the percentages in ``log_ticks`` land on equally
    spaced positions in [0, 1]. Between ticks the mapping is linear. Rates below
    1 % or above 100 % are clamped onto the edge positions by ``np.interp``.

    Args:
        fpr: false-positive (false-acceptance) rates, array-like of fractions.
        fnr: false-negative (false-rejection) rates, same length as ``fpr``.
        label: legend label for the curve.
        ax: matplotlib Axes to draw on; defaults to the current axes.

    Returns:
        The ``Line2D`` that was added.
    """
    if ax is None:
        ax = plt.gca()

    x = np.interp(np.asarray(fpr) * 100, log_ticks, log_positions)
    y = np.interp(np.asarray(fnr) * 100, log_ticks, log_positions)
    (line,) = ax.plot(x, y, label=label)

    # Label the warped positions with the percentages they stand for.
    tick_labels = [f'{t}%' for t in log_ticks]
    ax.set_xticks(log_positions)
    ax.set_xticklabels(tick_labels)
    ax.set_yticks(log_positions)
    ax.set_yticklabels(tick_labels)
    return line


In [None]:
%%time
exp_dir = 'eval_fsl.en263.ncm.5pos.50neg.1shot.100eps'
model_dir = {
    'BCResNet': Path('results/TL_EN_MSWC500U_BCResNet_80W_20Q_400E_40Epoch') / exp_dir,
    'ConvNextV2': Path('results/TL_MSWC500U_ConvNeXtV2_Atto_Tung') / exp_dir,
    'ResNet15': Path('results/EN_MSWC500U_Resnet15_1x49x10_to_64_80W_20Q_400E_40Epoch') / exp_dir,
    'DSCNN_PN': Path('results/PN_MSWC500U_DSCNNLLN') / exp_dir,
    'DSCNN_TL': Path('results/TL_MSWC500U_DSCNNLLN') / exp_dir,
}
device = 'cuda:1'
n_episodes = 100 # ensure all models eval on this many episodes
n_labels = 5

with concurrent.futures.ProcessPoolExecutor() as exe:
    tasks = [
        exe.submit(compute_model, exp_dir, n_episodes, n_labels, device)
        for exp_dir in model_dir.values()
    ]
    results = [task.result() for task in tasks]



In [None]:
# One square figure holding a DET curve per evaluated model.
fig, ax = plt.subplots(figsize=(13, 13))

# plot_det draws on the current axes, which is `ax` right after plt.subplots.
for model, (fpr, fnr) in zip(model_dir.keys(), results):
    plot_det(fpr, fnr, model)

ax.set_title("Detection Error Tradeoff (DET) curves", pad=20, fontsize=15)
ax.set_xlabel("False Acceptance Rate (in %)", labelpad=10, fontsize=12)
ax.set_ylabel("False Rejection Rate (in %)", labelpad=10, fontsize=12)
ax.legend()
ax.grid()
fig.savefig('det.en263_1.5pos.50neg.1shot.100eps.pdf')