# Anomaly Score comparison

In [1]:
import os
import sys

module_path = os.path.abspath(os.path.join(".."))
sys.path.append(module_path)
import torch
from torchvision import datasets

from anomaly_scores.energy import energy_anomaly_score
from anomaly_scores.max_logit import max_logit_anomaly_score
from anomaly_scores.softmax import max_softmax_anomaly_score
from anomaly_scores.vim_scores import VIM
from energy_ood.CIFAR.models.wrn import WideResNet
from energy_ood.utils.svhn_loader import SVHN
from util import TEST_TRANSFORM
from util.display_results import compare_all_results
from util.get_ood_score import get_ood_score_for_multiple_datasets
from vim_training.test import test

## The data
Let's start with replicating the results from the paper. First, with the SVHN data set. 

In [2]:
loaders = []

id_data = datasets.CIFAR10("../data/cifar10", train=False, transform=TEST_TRANSFORM)
id_loader = torch.utils.data.DataLoader(
    id_data, batch_size=200, shuffle=False, num_workers=2, pin_memory=True
)
loaders.append(("CIFAR10", id_loader))


ood_data = SVHN(
    root="../data/svhn/",
    split="test",
    transform=TEST_TRANSFORM,
    download=False,
)
ood_loader = torch.utils.data.DataLoader(
    ood_data, batch_size=200, shuffle=True, num_workers=2, pin_memory=True
)
ood_num_examples = len(loaders[0][1].dataset) // 5
loaders.append(("SVHN", ood_loader))


data = datasets.CIFAR100("../data/cifar-100", train=False, transform=TEST_TRANSFORM)
loader = torch.utils.data.DataLoader(
    data, batch_size=200, shuffle=True, num_workers=2, pin_memory=True
)
loaders.append(("CIFAR100", loader))

## Models

We are using the Wide ResNet as in the paper.

In [3]:
models = []

model_folder = "../snapshots/pretrained/"

for filename in  next(os.walk(model_folder), (None, None, []))[2]:
    model_name = filename.split(".")[0].replace("_"," ")
    print(model_name)
    model = WideResNet(depth=40, num_classes=10, widen_factor=2, dropRate=0.3)
    model.load_state_dict(
        torch.load(
            model_folder + filename
        )
    )
    model.eval()
    _ = model.cuda()
    models.append((model_name, model))

WRN Hendrycks Calib Seed1
WRN Ours Seed1
WRN Hendrycks Seed1
WRN Hendrycks Seed64
WRN Hendrycks Seed42


# Anomaly Scores
Let's compare the scores.

In [4]:
import numpy as np
aurocs_results = {}
for model_name, model in models:
    print(model_name)
    aurocs_results[model_name] = {}
    vim = VIM(id_loader, model)

    scores = [
        ("MaxLogit", max_logit_anomaly_score),
        ("MaxSoftmax", max_softmax_anomaly_score),
        ("Energy", energy_anomaly_score),
        ("VIM", vim.compute_anomaly_score),
    ]

    _, test_accuracy = test(model, loaders[0][1])
    aurocs_results[model_name]["test_acc"] = test_accuracy

    for name, score in scores:
        print("  ", name)
        results = get_ood_score_for_multiple_datasets(
            loaders,
            model,
            score,
            is_using="last" if not name == "VIM" else "last_penultimate",
        )
        aurocs = [np.mean(aurocs) for aurocs, _, _ in results]
        aurocs.append(np.mean(aurocs))
        aurocs_results[model_name][name] = aurocs

WRN Hendrycks Calib Seed1
   MaxLogit
   MaxSoftmax
   Energy
   VIM
WRN Ours Seed1
   MaxLogit
   MaxSoftmax
   Energy
   VIM
WRN Hendrycks Seed1
   MaxLogit
   MaxSoftmax
   Energy
   VIM
WRN Hendrycks Seed64
   MaxLogit
   MaxSoftmax
   Energy
   VIM
WRN Hendrycks Seed42
   MaxLogit
   MaxSoftmax
   Energy
   VIM


In [5]:
compare_all_results(aurocs_results, loaders)

WRN Hendrycks Calib Seed1 (5.58%) |     SVHN     |   CIFAR100   |     AVG     
                 MaxLogit |    86.57%    |    85.27%    |    85.92%   
               MaxSoftmax |    88.17%    |   *87.08%    |    87.63%   
                   Energy |    85.98%    |    85.98%    |    85.98%   
                      VIM |   *90.28%    |    85.68%    |   *87.98%   

   WRN Ours Seed1 (5.18%) |     SVHN     |   CIFAR100   |     AVG     
                 MaxLogit |    81.56%    |   *87.78%    |    84.67%   
               MaxSoftmax |    86.82%    |    87.55%    |    87.18%   
                   Energy |    81.64%    |    86.88%    |    84.26%   
                      VIM |   *96.15%    |    86.95%    |   *91.55%   

WRN Hendrycks Seed1 (5.15%) |     SVHN     |   CIFAR100   |     AVG     
                 MaxLogit |    91.04%    |    86.89%    |    88.97%   
               MaxSoftmax |    92.18%    |   *88.40%    |    90.29%   
                   Energy |    91.19%    |    86.82%    |    89.0