# Anomaly Score comparison

In [1]:
import os
import sys

module_path = os.path.abspath(os.path.join(".."))
sys.path.append(module_path)
import torch
from torchvision import datasets

from anomaly_scores.energy import energy_anomaly_score
from anomaly_scores.max_logit import max_logit_anomaly_score
from anomaly_scores.softmax import max_softmax_anomaly_score
from anomaly_scores.vim_scores import VIM
from energy_ood.CIFAR.models.wrn import WideResNet
from energy_ood.utils.svhn_loader import SVHN
from util import TEST_TRANSFORM
from util.display_results import compare_all_results
from util.get_ood_score import get_ood_score_for_multiple_datasets
from vim_training.test import test

## The data
Let's start with replicating the results from the paper. First, with the SVHN data set. 

In [2]:
loaders = []

id_data = datasets.CIFAR10("../data/cifar10", train=False, transform=TEST_TRANSFORM)
id_loader = torch.utils.data.DataLoader(
    id_data, batch_size=200, shuffle=False, num_workers=2, pin_memory=True
)
loaders.append(("CIFAR10", id_loader))


ood_data = SVHN(
    root="../data/svhn/",
    split="test",
    transform=TEST_TRANSFORM,
    download=False,
)
ood_loader = torch.utils.data.DataLoader(
    ood_data, batch_size=200, shuffle=True, num_workers=2, pin_memory=True
)
ood_num_examples = len(loaders[0][1].dataset) // 5
loaders.append(("SVHN", ood_loader))


data = datasets.CIFAR100("../data/cifar-100", train=False, transform=TEST_TRANSFORM)
loader = torch.utils.data.DataLoader(
    data, batch_size=200, shuffle=True, num_workers=2, pin_memory=True
)
loaders.append(("CIFAR100", loader))

## Models

We are using the Wide ResNet as in the paper.

In [3]:
models = []

model_folder = "../snapshots/pretrained/"

for filename in next(os.walk(model_folder), (None, None, []))[2]:
    if "WRN_Hendrycks_Seed" in filename:
        model_name = filename.split(".")[0].split("_")[-1]
        print(model_name)
        model = WideResNet(depth=40, num_classes=10, widen_factor=2, dropRate=0.3)
        model.load_state_dict(torch.load(model_folder + filename))
        model.eval()
        _ = model.cuda()
        models.append((model_name, model))

Seed1
Seed64
Seed42


In [4]:
filename

'WRN_Hendrycks_Seed42.pt'

# Anomaly Scores
Let's compare the scores.

In [5]:
import numpy as np

aurocs_results = {}
auprs_results = {}
all_results = []
for model_name, model in models:
    print(model_name)
    aurocs_results[model_name] = {}
    auprs_results[model_name] = {}
    vim = VIM(id_loader, model)

    scores = [
        ("MaxLogit", max_logit_anomaly_score),
        ("MaxSoftmax", max_softmax_anomaly_score),
        ("Energy", energy_anomaly_score),
        ("VIM", vim.compute_anomaly_score),
    ]

    _, test_accuracy = test(model, loaders[0][1])
    aurocs_results[model_name]["test_acc"] = test_accuracy
    auprs_results[model_name]["test_acc"] = test_accuracy
    model_results = []

    for score_name, score in scores:
        print("  ", score_name)
        results = get_ood_score_for_multiple_datasets(
            loaders,
            model,
            score,
            is_using="last" if not score_name == "VIM" else "last_penultimate",
            runs=3,
        )
        aurocs = [np.mean(aurocs) for aurocs, _, _ in results]
        aurocs.append(np.mean(aurocs))
        aurocs_results[model_name][score_name] = aurocs
        auprs = [np.mean(auprs) for _, auprs, _ in results]
        auprs.append(np.mean(auprs))
        auprs_results[model_name][score_name] = auprs
        model_results.append(results)
    all_results.append(model_results)

Seed1
   MaxLogit
   MaxSoftmax
   Energy
   VIM
Seed64
   MaxLogit
   MaxSoftmax
   Energy
   VIM
Seed42
   MaxLogit
   MaxSoftmax
   Energy
   VIM


In [6]:
compare_all_results(aurocs_results, loaders)

            Seed1 (5.15%) |     SVHN     |   CIFAR100   |     AVG     
                 MaxLogit |    91.26%    |    87.51%    |    89.39%   
               MaxSoftmax |    92.08%    |   *87.98%    |    90.03%   
                   Energy |    91.17%    |    86.93%    |    89.05%   
                      VIM |   *95.00%    |    85.74%    |   *90.37%   

           Seed64 (5.53%) |     SVHN     |   CIFAR100   |     AVG     
                 MaxLogit |   *93.95%    |    86.51%    |    90.23%   
               MaxSoftmax |    93.54%    |   *87.49%    |   *90.51%   
                   Energy |    93.76%    |    86.28%    |    90.02%   
                      VIM |    92.74%    |    86.26%    |    89.50%   

           Seed42 (5.36%) |     SVHN     |   CIFAR100   |     AVG     
                 MaxLogit |    88.88%    |    87.19%    |    88.04%   
               MaxSoftmax |    88.76%    |   *87.76%    |   *88.26%   
                   Energy |   *89.30%    |    87.02%    |    88.16%   
    

In [7]:
compare_all_results(auprs_results, loaders)

            Seed1 (5.15%) |     SVHN     |   CIFAR100   |     AVG     
                 MaxLogit |    91.27%    |   *87.02%    |    89.14%   
               MaxSoftmax |    90.52%    |    85.69%    |    88.10%   
                   Energy |    91.18%    |    86.74%    |    88.96%   
                      VIM |   *94.40%    |    84.37%    |   *89.39%   

           Seed64 (5.53%) |     SVHN     |   CIFAR100   |     AVG     
                 MaxLogit |    91.50%    |    85.48%    |    88.49%   
               MaxSoftmax |    90.10%    |    84.30%    |    87.20%   
                   Energy |   *91.52%    |   *85.50%    |   *88.51%   
                      VIM |    90.95%    |    85.11%    |    88.03%   

           Seed42 (5.36%) |     SVHN     |   CIFAR100   |     AVG     
                 MaxLogit |   *87.18%    |   *86.24%    |   *86.71%   
               MaxSoftmax |    85.44%    |    84.85%    |    85.15%   
                   Energy |    87.17%    |    86.14%    |    86.65%   
    

# Latex Prepartion

In [8]:
import numpy as np
from energy_ood.utils.display_results import get_measures

from util.get_ood_score import get_ood_scores

RUNS = 10
all_results = {}
for model_name, model in models:
    print(model_name)

    vim = VIM(id_loader, model)
    scores = [
        ("MaxLogit", max_logit_anomaly_score),
        ("MaxSoftmax", max_softmax_anomaly_score),
        ("Energy", energy_anomaly_score),
        ("VIM", vim.compute_anomaly_score),
        # ("Test Error", None),
    ]

    # _, test_accuracy = test(model, loaders[0][1])

    ood_num_examples = len(loaders[0][1].dataset) // 5
    all_results[model_name] = {}
    for score_name, score in scores:
        print("  ", score_name)
        in_score = get_ood_scores(
            loaders[0][1],
            model,
            score,
            ood_num_examples,
            in_dist=True,
            is_using="last" if not score_name == "VIM" else "last_penultimate",
        )
        for ds_name, loader in loaders[1:]:
            if ds_name not in all_results[model_name].keys():
                all_results[model_name][ds_name] = {}
            print("    ", ds_name)
            all_results[model_name][ds_name][score_name] = {
                "AUROC": [],
                "AUPR": [],
                "FPR": [],
            }
            for _ in range(RUNS):
                out_score = get_ood_scores(
                    loader,
                    model,
                    score,
                    ood_num_examples,
                    is_using="last" if not score_name == "VIM" else "last_penultimate",
                )
                auroc, aupr, fpr = get_measures(out_score[:], in_score[:])
                all_results[model_name][ds_name][score_name]["AUROC"].append(auroc)
                all_results[model_name][ds_name][score_name]["AUPR"].append(aupr)
                all_results[model_name][ds_name][score_name]["FPR"].append(fpr)

Seed1
   MaxLogit
     SVHN
     CIFAR100
   MaxSoftmax
     SVHN
     CIFAR100
   Energy
     SVHN
     CIFAR100
   VIM
     SVHN
     CIFAR100
Seed64
   MaxLogit
     SVHN
     CIFAR100
   MaxSoftmax
     SVHN
     CIFAR100
   Energy
     SVHN
     CIFAR100
   VIM
     SVHN
     CIFAR100
Seed42
   MaxLogit
     SVHN
     CIFAR100
   MaxSoftmax
     SVHN
     CIFAR100
   Energy
     SVHN
     CIFAR100
   VIM
     SVHN
     CIFAR100


In [9]:
HAS_STD = False

In [10]:
mean_std_results = []
for i in all_results.keys():
    res = []
    for j in all_results[i].keys():
        for k in all_results[i][j].keys():
            for l in all_results[i][j][k].keys():
                res.append(np.mean(all_results[i][j][k][l]))
                if HAS_STD:
                    res.append(np.std(all_results[i][j][k][l]))
    mean_std_results.append(res)

In [11]:
np.array(mean_std_results).shape

(3, 24)

In [12]:
import pandas as pd

col_stats = ["Mean"]
if HAS_STD:
    col_stats.append["Std"]
cols = pd.MultiIndex.from_tuples(
    [
        (j, k, l, m)
        for i in [next(iter(all_results))]
        for j in all_results[i].keys()
        for k in all_results[i][j].keys()
        for l in all_results[i][j][k].keys()
        for m in col_stats
    ],
    names=["Dataset", "Score", "Measure", "Statistics"],
)
df = pd.DataFrame(np.array(mean_std_results)*100, columns=cols, index=all_results.keys()
)
stats = pd.DataFrame()
stats["Mean"] = df.mean(axis=0)
stats["Std"] = df.std(axis=0)
(df.style.highlight_max(axis=0, props='background-color:green;')
         .highlight_min(axis=0, props='background-color:red;'))

Dataset,SVHN,SVHN,SVHN,SVHN,SVHN,SVHN,SVHN,SVHN,SVHN,SVHN,SVHN,SVHN,CIFAR100,CIFAR100,CIFAR100,CIFAR100,CIFAR100,CIFAR100,CIFAR100,CIFAR100,CIFAR100,CIFAR100,CIFAR100,CIFAR100
Score,MaxLogit,MaxLogit,MaxLogit,MaxSoftmax,MaxSoftmax,MaxSoftmax,Energy,Energy,Energy,VIM,VIM,VIM,MaxLogit,MaxLogit,MaxLogit,MaxSoftmax,MaxSoftmax,MaxSoftmax,Energy,Energy,Energy,VIM,VIM,VIM
Measure,AUROC,AUPR,FPR,AUROC,AUPR,FPR,AUROC,AUPR,FPR,AUROC,AUPR,FPR,AUROC,AUPR,FPR,AUROC,AUPR,FPR,AUROC,AUPR,FPR,AUROC,AUPR,FPR
Statistics,Mean,Mean,Mean,Mean,Mean,Mean,Mean,Mean,Mean,Mean,Mean,Mean,Mean,Mean,Mean,Mean,Mean,Mean,Mean,Mean,Mean,Mean,Mean,Mean
Seed1,91.187296,91.189539,45.065,92.285029,90.706797,28.345,91.120215,91.021225,45.54,94.904821,94.343339,22.88,87.278997,86.914468,57.62,88.163276,85.92066,44.88,86.992341,86.766238,58.42,85.752845,84.338372,56.805
Seed64,93.770966,91.390747,23.8,93.54737,90.253254,17.71,93.791355,91.456397,23.425,92.96861,91.016117,25.85,86.130041,85.160993,61.44,87.387522,84.049613,44.74,86.367621,85.500403,59.935,86.283511,85.105657,55.835
Seed42,89.432849,87.487592,45.655,88.764348,85.463421,39.99,89.358359,87.399148,46.59,86.364781,81.828119,41.245,87.584438,86.604311,54.845,87.889656,84.969121,43.235,87.288575,86.46195,56.825,83.874845,82.242205,61.63


In [13]:
stats.T

Dataset,SVHN,SVHN,SVHN,SVHN,SVHN,SVHN,SVHN,SVHN,SVHN,SVHN,...,CIFAR100,CIFAR100,CIFAR100,CIFAR100,CIFAR100,CIFAR100,CIFAR100,CIFAR100,CIFAR100,CIFAR100
Score,MaxLogit,MaxLogit,MaxLogit,MaxSoftmax,MaxSoftmax,MaxSoftmax,Energy,Energy,Energy,VIM,...,MaxLogit,MaxSoftmax,MaxSoftmax,MaxSoftmax,Energy,Energy,Energy,VIM,VIM,VIM
Measure,AUROC,AUPR,FPR,AUROC,AUPR,FPR,AUROC,AUPR,FPR,AUROC,...,FPR,AUROC,AUPR,FPR,AUROC,AUPR,FPR,AUROC,AUPR,FPR
Statistics,Mean,Mean,Mean,Mean,Mean,Mean,Mean,Mean,Mean,Mean,...,Mean,Mean,Mean,Mean,Mean,Mean,Mean,Mean,Mean,Mean
Mean,91.463704,90.022626,38.173333,91.532249,88.807824,28.681667,91.42331,89.958923,38.518333,91.412738,...,57.968333,87.813485,84.979798,44.285,86.882846,86.242864,58.393333,85.303734,83.895411,58.09
Std,2.182227,2.197708,12.451167,2.478777,2.905203,11.143815,2.231986,2.227483,13.081749,4.477569,...,3.31127,0.393446,0.935569,0.912017,0.470139,0.660745,1.555171,1.26558,1.482228,3.103856
