# Anomaly Score comparison

In [1]:
import os
import sys

module_path = os.path.abspath(os.path.join(".."))
sys.path.append(module_path)
import torch
from torchvision import datasets

from anomaly_scores.energy import energy_anomaly_score
from anomaly_scores.max_logit import max_logit_anomaly_score
from anomaly_scores.softmax import max_softmax_anomaly_score
from anomaly_scores.vim_scores import VIM
from energy_ood.CIFAR.models.wrn import WideResNet
from energy_ood.utils.svhn_loader import SVHN
from util import TEST_TRANSFORM
from vim_training.model import WideResVIMNet

from util.display_results import compare_all_results
from util.get_ood_score import get_ood_score_for_multiple_datasets
from vim_training.testing import test

## The data
Let's start by replicating the results from the paper. CIFAR-10 is the in-distribution data set; SVHN comes first as out-of-distribution data, followed by CIFAR-100.

In [2]:
# Seed the global torch RNG so the shuffled OOD loaders below draw the same
# sample order after every "Restart Kernel -> Run All". Without it the
# AUROC/AUPR tables change from run to run.
SEED = 0
torch.manual_seed(SEED)

# Settings shared by all three loaders; only `shuffle` differs between them.
_LOADER_KWARGS = dict(batch_size=200, num_workers=2, pin_memory=True)

# (name, loader) pairs. Index 0 is the in-distribution (ID) set; later cells
# iterate over `loaders[1:]` as the out-of-distribution (OOD) sets.
loaders = []

# ID data: CIFAR-10 test split, iterated in a fixed order.
id_data = datasets.CIFAR10("../data/cifar10", train=False, transform=TEST_TRANSFORM)
id_loader = torch.utils.data.DataLoader(id_data, shuffle=False, **_LOADER_KWARGS)
loaders.append(("CIFAR10", id_loader))

# One fifth of the ID test-set size; passed to `get_ood_scores` as the
# example budget in the LaTeX-preparation cell.
ood_num_examples = len(loaders[0][1].dataset) // 5

# OOD data 1: SVHN test split (must already be on disk, download=False).
ood_data = SVHN(
    root="../data/svhn/",
    split="test",
    transform=TEST_TRANSFORM,
    download=False,
)
ood_loader = torch.utils.data.DataLoader(ood_data, shuffle=True, **_LOADER_KWARGS)
loaders.append(("SVHN", ood_loader))

# OOD data 2: CIFAR-100 test split.
data = datasets.CIFAR100("../data/cifar-100", train=False, transform=TEST_TRANSFORM)
loader = torch.utils.data.DataLoader(data, shuffle=True, **_LOADER_KWARGS)
loaders.append(("CIFAR100", loader))

## Models

We are using the Wide ResNet as in the paper.

In [3]:
def _load_wrvm(snapshot_path, is_using_vim):
    """Build a WideResVIMNet, restore its weights, and put it on the GPU in eval mode.

    Args:
        snapshot_path: Path of the saved state dict to restore.
        is_using_vim: Forwarded unchanged to ``WideResVIMNet(is_using_vim=...)``.

    Returns:
        The restored model, in eval mode, after ``.cuda()`` has been called on it.
    """
    net = WideResVIMNet(
        depth=40,
        num_classes=10,
        loader=loaders[0][1],
        widen_factor=2,
        dropRate=0.3,
        is_using_vim=is_using_vim,
    )
    # Deserialize onto the CPU first so a snapshot saved from a different GPU
    # index still loads; the model is moved to CUDA right after.
    net.load_state_dict(torch.load(snapshot_path, map_location="cpu"))
    net.eval()
    net.cuda()
    return net


# (display name, snapshot path, is_using_vim) for every evaluated model.
# The display names matter: later cells test `"with VIM" in model_name` to
# decide how the VIM score is obtained.
_MODEL_SPECS = [
    ("WRVM", "../snapshots/pretrain_vim/CIFAR10_WRN_epoch_99.pt", False),
    ("WRVM trained with VIM", "../snapshots/train_with_vim/CIFAR10_WRN_epoch_99.pt", True),
    ("WRVM fine-tuned with VIM", "../snapshots/vim_ft/CIFAR10_WRN_epoch_9.pt", True),
]

models = []
for _name, _snapshot, _uses_vim in _MODEL_SPECS:
    model = _load_wrvm(_snapshot, _uses_vim)
    models.append((_name, model))

# Anomaly Scores
Let's compare the scores.

In [4]:
import numpy as np

# Per-model result tables: {model_name: {"test_acc": ..., score_name: [per-dataset means..., average]}}.
aurocs_results = {}
auprs_results = {}
# Raw outputs of get_ood_score_for_multiple_datasets, one list per model.
all_results = []

for model_name, model in models:
    print(model_name)
    trained_with_vim = "with VIM" in model_name
    aurocs_results[model_name] = {}
    auprs_results[model_name] = {}

    # Models trained with VIM expose their own scorer; for the others a VIM
    # object is built from the ID loader and the model.
    if trained_with_vim:
        vim_scoring = model.compute_anomaly_score
    else:
        vim = VIM(id_loader, model)
        vim_scoring = vim.compute_anomaly_score

    scores = [
        ("MaxLogit", max_logit_anomaly_score),
        ("MaxSoftmax", max_softmax_anomaly_score),
        ("Energy", energy_anomaly_score),
        ("VIM", vim_scoring),
    ]

    # The second value returned by `test` is stored under "test_acc" in both tables.
    _, test_accuracy = test(model, loaders[0][1])
    for table in (aurocs_results, auprs_results):
        table[model_name]["test_acc"] = test_accuracy

    model_results = []
    for score_name, score in scores:
        print("  ", score_name)
        # Only the external VIM scorer is given "last_penultimate"; every
        # other combination uses "last".
        needs_penultimate = score_name == "VIM" and not trained_with_vim
        results = get_ood_score_for_multiple_datasets(
            loaders,
            model,
            score,
            is_using="last_penultimate" if needs_penultimate else "last",
            runs=3,
        )

        # Mean over runs for each dataset, then the mean of those means as a trailing "AVG" entry.
        aurocs = [np.mean(run_aurocs) for run_aurocs, _, _ in results]
        aurocs.append(np.mean(aurocs))
        aurocs_results[model_name][score_name] = aurocs

        auprs = [np.mean(run_auprs) for _, run_auprs, _ in results]
        auprs.append(np.mean(auprs))
        auprs_results[model_name][score_name] = auprs

        model_results.append(results)
    all_results.append(model_results)

WRVM
   MaxLogit
   MaxSoftmax
   Energy
   VIM
WRVM trained with VIM
   MaxLogit
   MaxSoftmax
   Energy
   VIM
WRVM fine-tuned with VIM
   MaxLogit
   MaxSoftmax
   Energy
   VIM


In [5]:
# Print the AUROC comparison: one table per model, one row per anomaly score,
# one column per OOD dataset plus their average.
compare_all_results(aurocs_results, loaders)

             WRVM (5.01%) |     SVHN     |   CIFAR100   |     AVG     
                 MaxLogit |    90.11%    |    86.62%    |    88.37%   
               MaxSoftmax |    89.99%    |   *87.56%    |    88.77%   
                   Energy |    89.98%    |    86.21%    |    88.09%   
                      VIM |   *90.61%    |    87.13%    |   *88.87%   

WRVM trained with VIM (9.61%) |     SVHN     |   CIFAR100   |     AVG     
                 MaxLogit |    77.09%    |    65.78%    |    71.44%   
               MaxSoftmax |   *82.20%    |   *78.63%    |   *80.41%   
                   Energy |    76.93%    |    65.14%    |    71.04%   
                      VIM |    44.02%    |    44.50%    |    44.26%   

WRVM fine-tuned with VIM (5.05%) |     SVHN     |   CIFAR100   |     AVG     
                 MaxLogit |    53.86%    |    76.37%    |    65.11%   
               MaxSoftmax |    19.01%    |    52.24%    |    35.62%   
                   Energy |    54.85%    |    76.95%    |    65.

In [6]:
# Print the AUPR comparison: one table per model, one row per anomaly score,
# one column per OOD dataset plus their average.
compare_all_results(auprs_results, loaders)

             WRVM (5.01%) |     SVHN     |   CIFAR100   |     AVG     
                 MaxLogit |    87.68%    |    86.40%    |    87.04%   
               MaxSoftmax |    86.22%    |    85.00%    |    85.61%   
                   Energy |    87.51%    |    85.89%    |    86.70%   
                      VIM |   *87.89%    |   *86.43%    |   *87.16%   

WRVM trained with VIM (9.61%) |     SVHN     |   CIFAR100   |     AVG     
                 MaxLogit |    63.11%    |    56.40%    |    59.75%   
               MaxSoftmax |   *80.43%    |   *78.25%    |   *79.34%   
                   Energy |    62.82%    |    55.65%    |    59.24%   
                      VIM |    43.68%    |    46.06%    |    44.87%   

WRVM fine-tuned with VIM (5.05%) |     SVHN     |   CIFAR100   |     AVG     
                 MaxLogit |    50.57%    |    73.10%    |    61.83%   
               MaxSoftmax |    35.34%    |    56.82%    |    46.08%   
                   Energy |    51.68%    |    74.24%    |    62.

# LaTeX Preparation

In [7]:
import numpy as np
from energy_ood.utils.display_results import get_measures

from util.get_ood_score import get_ood_scores

# Number of times each OOD loader is re-scored per (model, score, dataset).
RUNS = 10

# Example budget handed to `get_ood_scores`: one fifth of the ID test-set size.
# It does not depend on the model, so it is computed once up front.
ood_num_examples = len(loaders[0][1].dataset) // 5

# {model_name: {dataset_name: {score_name: {"AUROC": [...], "AUPR": [...], "FPR": [...]}}}}
all_results = {}
for model_name, model in models:
    print(model_name)
    trained_with_vim = "with VIM" in model_name

    # Models trained with VIM expose their own scorer; for the others a VIM
    # object is built from the ID loader and the model.
    if trained_with_vim:
        vim_scoring = model.compute_anomaly_score
    else:
        vim = VIM(id_loader, model)
        vim_scoring = vim.compute_anomaly_score

    scores = [
        ("MaxLogit", max_logit_anomaly_score),
        ("MaxSoftmax", max_softmax_anomaly_score),
        ("Energy", energy_anomaly_score),
        ("VIM", vim_scoring),
    ]

    all_results[model_name] = {}
    for score_name, score in scores:
        # Only the external VIM scorer is given "last_penultimate"; every
        # other combination uses "last".
        if score_name == "VIM" and not trained_with_vim:
            is_using = "last_penultimate"
        else:
            is_using = "last"
        print("  ", score_name)

        # ID scores are computed once per (model, score) and reused for every
        # OOD dataset and run.
        in_score = get_ood_scores(
            loaders[0][1],
            model,
            score,
            ood_num_examples,
            in_dist=True,
            is_using=is_using,
        )

        for ds_name, loader in loaders[1:]:
            print("    ", ds_name)
            measures = {"AUROC": [], "AUPR": [], "FPR": []}
            all_results[model_name].setdefault(ds_name, {})[score_name] = measures

            for _ in range(RUNS):
                out_score = get_ood_scores(
                    loader,
                    model,
                    score,
                    ood_num_examples,
                    is_using=is_using,
                )
                # NOTE(review): OOD scores are passed first and ID scores second;
                # confirm this matches get_measures' positive/negative convention.
                auroc, aupr, fpr = get_measures(out_score, in_score)
                measures["AUROC"].append(auroc)
                measures["AUPR"].append(aupr)
                measures["FPR"].append(fpr)

WRVM
   MaxLogit
     SVHN
     CIFAR100
   MaxSoftmax
     SVHN
     CIFAR100
   Energy
     SVHN
     CIFAR100
   VIM
     SVHN
     CIFAR100
WRVM trained with VIM
   MaxLogit
     SVHN
     CIFAR100
   MaxSoftmax
     SVHN
     CIFAR100
   Energy
     SVHN
     CIFAR100
   VIM
     SVHN
     CIFAR100
WRVM fine-tuned with VIM
   MaxLogit
     SVHN
     CIFAR100
   MaxSoftmax
     SVHN
     CIFAR100
   Energy
     SVHN
     CIFAR100
   VIM
     SVHN
     CIFAR100


In [8]:
# When True, a standard-deviation column is emitted next to every mean in the table below.
HAS_STD = False
# Measures to tabulate; each entry must be a key of the per-score result dicts
# ("AUROC", "AUPR" or "FPR").
MEASURES = ["AUROC"]

In [9]:
# One flat row per model. Values are ordered dataset -> score -> measure, with
# the mean first and, when HAS_STD is set, the standard deviation right after it.
mean_std_results = []
for per_dataset in all_results.values():
    row = []
    for per_score in per_dataset.values():
        for run_measures in per_score.values():
            for measure_name in MEASURES:
                samples = run_measures[measure_name]
                row.append(np.mean(samples))
                if HAS_STD:
                    row.append(np.std(samples))
    mean_std_results.append(row)

In [10]:
import pandas as pd

# Statistics reported per (dataset, score, measure) column group.
col_stats = ["Mean", "Std"] if HAS_STD else ["Mean"]

# Column labels are taken from the first model's results only; this assumes
# every model was evaluated on the same datasets and scores, in the same order.
first_model_results = all_results[next(iter(all_results))]
cols = pd.MultiIndex.from_tuples(
    [
        (ds_name, score_name, measure_name, stat_name)
        for ds_name, per_score in first_model_results.items()
        for score_name in per_score.keys()
        for measure_name in MEASURES
        for stat_name in col_stats
    ],
    names=["Dataset", "Score", "Measure", "Statistics"],
)

# One row per model; values are scaled by 100 for display.
values_pct = np.array(mean_std_results) * 100
df = pd.DataFrame(values_pct, columns=cols, index=all_results.keys())

# Across-model mean and standard deviation for every column (not displayed by this cell).
stats = pd.DataFrame()
stats["Mean"] = df.mean(axis=0)
stats["Std"] = df.std(axis=0)

# Highlight the best (green) and worst (red) model in each column.
(
    df.style
    .highlight_max(axis=0, props='background-color:green;')
    .highlight_min(axis=0, props='background-color:red;')
)

Dataset,SVHN,SVHN,SVHN,SVHN,CIFAR100,CIFAR100,CIFAR100,CIFAR100
Score,MaxLogit,MaxSoftmax,Energy,VIM,MaxLogit,MaxSoftmax,Energy,VIM
Measure,AUROC,AUROC,AUROC,AUROC,AUROC,AUROC,AUROC,AUROC
Statistics,Mean,Mean,Mean,Mean,Mean,Mean,Mean,Mean
WRVM,89.951248,90.018159,90.025974,90.535066,86.377777,87.790641,86.550521,86.916533
WRVM trained with VIM,77.253136,82.263391,76.910951,44.070554,65.644625,78.541579,64.747599,45.127237
WRVM fine-tuned with VIM,53.979533,19.394994,54.393842,98.974823,76.713999,52.84355,76.563275,93.81421
