# Anomaly Score comparison

In [1]:
import torch
from torchvision import datasets, transforms

from anomaly_scores.energy import energy_anomaly_score
from anomaly_scores.max_logit import max_logit_anomaly_score
from anomaly_scores.softmax import max_softmax_anomaly_score
from anomaly_scores.vim_scores import VIM
from energy_ood.CIFAR.models.wrn import WideResNet
from energy_ood.utils.svhn_loader import SVHN
from util import TEST_TRANSFORM
from util.display_results import compare_all_results
from util.get_ood_score import get_ood_score_for_multiple_datasets

## Model

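We use the pretrained CIFAR-10 Wide ResNet (WRN-40-2, dropout 0.3) from the energy-based OOD detection paper, loaded from the pretrained snapshot shipped with the `energy_ood` repository.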

In [2]:
# Pretrained CIFAR-10 WRN-40-2 checkpoint from the energy_ood checkout.
WRN_CHECKPOINT = (
    "energy_ood/CIFAR/snapshots/pretrained/cifar10_wrn_pretrained_epoch_99.pt"
)

model = WideResNet(depth=40, num_classes=10, widen_factor=2, dropRate=0.3)
# map_location="cpu" makes loading independent of the device the checkpoint was
# saved from; the weights are moved to the GPU by model.cuda() below.
model.load_state_dict(torch.load(WRN_CHECKPOINT, map_location="cpu"))
# Inference mode: dropout off, batch-norm uses its running statistics.
model.eval()
_ = model.cuda()  # assign to `_` so the cell does not print the module repr

## The data
Let's start by replicating the results from the paper, beginning with SVHN as the out-of-distribution (OOD) dataset. The CIFAR-10 test split serves as the in-distribution (ID) data.

In [3]:
# Build the evaluation loaders. `loaders` is a list of (name, DataLoader) pairs:
# the first entry is the in-distribution set (CIFAR-10 test split); the results
# table below reports only the later entries as OOD datasets.
BATCH_SIZE = 200
NUM_WORKERS = 2
SEED = 0


def make_loader(dataset, shuffle=False, generator=None):
    """Wrap ``dataset`` in a DataLoader using the loader settings shared by this notebook.

    Args:
        dataset: any map-style dataset.
        shuffle: whether to reshuffle the sample order each epoch.
        generator: optional ``torch.Generator`` controlling the shuffle order.

    Returns:
        A ``torch.utils.data.DataLoader`` over ``dataset``.
    """
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        generator=generator,
    )


id_data = datasets.CIFAR10("data/cifar10", train=False, transform=TEST_TRANSFORM)
id_loader = make_loader(id_data)

ood_data = SVHN(
    root="data/svhn/",
    split="test",
    transform=TEST_TRANSFORM,
    download=False,
)
# Shuffling is kept, but with a seeded generator so the batch order is
# reproducible across Restart & Run All.
ood_loader = make_loader(
    ood_data, shuffle=True, generator=torch.Generator().manual_seed(SEED)
)

# NOTE(review): not passed to any scoring call in this notebook. It presumably
# mirrors energy_ood's convention of evaluating on an OOD subset of 1/5 the ID
# test size -- confirm whether the OOD evaluation should be limited to this many samples.
ood_num_examples = len(id_data) // 5

loaders = [("CIFAR10", id_loader), ("SVHN", ood_loader)]

## Anomaly Scores
Let's compare four anomaly scores: maximum logit, maximum softmax probability, energy, and ViM (Virtual-logit Matching).

In [4]:
# ViM is constructed from in-distribution data and the model.
# NOTE(review): id_loader is the CIFAR-10 *test* split, which is also scored as
# in-distribution below. If VIM fits its parameters (e.g. principal subspace /
# scaling) on this loader -- the ViM paper fits on training data -- the ViM AUROC
# may be optimistically biased; consider fitting on a train-split loader instead.
vim = VIM(id_loader, model)

MODEL_NAME = "WRN"

scores = [
    ("MaxLogit", max_logit_anomaly_score),
    ("MaxSoftmax", max_softmax_anomaly_score),
    ("Energy", energy_anomaly_score),
    ("VIM", vim.compute_anomaly_score),
]

# Which model outputs each scorer consumes (`is_using`); any score not listed
# uses "last". ViM presumably needs penultimate features as well as logits.
score_inputs = {"VIM": "last_penultimate"}

# all_anomaly_results[model name][score name] -> per-dataset OOD results.
all_anomaly_results = {MODEL_NAME: {}}
for name, score in scores:
    all_anomaly_results[MODEL_NAME][name] = get_ood_score_for_multiple_datasets(
        loaders,
        model,
        score,
        is_using=score_inputs.get(name, "last"),
    )

In [5]:
# Tabulate AUROC per anomaly score (rows) and per OOD dataset (columns, plus the average).
compare_all_results(all_anomaly_results, loaders)

                      WRN |     SVHN     |     AVG     
                 MaxLogit |    91.45%    |    91.45%   
               MaxSoftmax |    92.34%    |    92.34%   
                   Energy |    91.30%    |    91.30%   
                      VIM |   *94.57%    |   *94.57%   


* highlights the maximum AUROC Score for an OOD Dataset


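Only the Virtual-logit Matching (ViM) score clearly improves on the others, reaching 94.57% AUROC on SVHN. Among the remaining scores, max-softmax (92.34%) is ahead of max-logit (91.45%) and energy (91.30%). Caveat: `VIM` is constructed from `id_loader`, the same CIFAR-10 test split that is scored as in-distribution. If it fits its parameters on that data, the ViM result may be optimistically biased.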