In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt

from importlib import reload

import models
import plotting
import dataloaders as dl
import traintest as tt

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [2]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Load the two saved models and place them on the selected device.
# map_location remaps tensors that were saved from a GPU, so loading also works
# when `device` fell back to the CPU.
# NOTE(review): the checkpoints appear to be fully pickled model objects (the
# loaded value is used directly as a model). Unpickling can execute arbitrary
# code, so only load files from a trusted source. Recent PyTorch releases
# default to weights_only=True, which rejects pickled modules -- pass
# weights_only=False there if these loads start failing.
base_model = torch.load('SavedModels/base_model.pth', map_location=device).to(device)
gmm_model = torch.load('SavedModels/gmm_model.pth', map_location=device).to(device)

In [164]:
def test_metrics(model, device, in_loader, out_loader, thresholds=torch.linspace(.1, 1., 1000).to(device)):
    with torch.no_grad():
        model.eval()
        conf_in = []
        conf_out = []
        for ((batch_idx, (data_in, _)), (_, (data_out, _))) in zip(enumerate(in_loader),enumerate(out_loader)):
            data_in = data_in.to(device)
            data_out = data_out.to(device)

            output_in = model(data_in).max(1)[0].exp()
            output_out = model(data_out).max(1)[0].exp()

            conf_in.append(output_in)
            conf_out.append(output_out)
        conf_in = torch.cat(conf_in)
        conf_out = torch.cat(conf_out)
        
        l = min(len(in_loader.dataset), len(out_loader.dataset))
        tp = (conf_in[:,None] > thresholds[None,:]).sum(0).float()/l
        fp = (conf_out[:,None] > thresholds[None,:]).sum(0).float()/l
        
        mmc = conf_out.mean()
        auroc = -np.trapz(tp.cpu().numpy(), x=fp.cpu().numpy())
        fp95 = (conf_out > 0.95).sum().float()/l
        return mmc, auroc, fp95

In [165]:
# Score the base model with dl.test_loader as the in-distribution source and
# dl.EMNIST_test_loader as the out-distribution source.
mmc, auroc, fp95 = test_metrics(base_model, device, dl.test_loader, dl.EMNIST_test_loader)

In [166]:
# Mean top-class confidence of the base model on the EMNIST (out) batches.
mmc

tensor(0.7871, device='cuda:0')

In [167]:
# Area under the thresholded (fp, tp) curve returned by test_metrics.
auroc

0.8842933

In [168]:
# Share of EMNIST (out) samples scored with confidence above 0.95.
# NOTE(review): this is not the usual "FPR at 95% TPR" metric -- confirm which one is intended.
fp95

tensor(0.3304, device='cuda:0')