# Anomaly Score comparison

In [1]:
import os
import sys

# Make the parent of the current working directory importable so the
# project-local packages imported below can be resolved.
# The membership guard keeps the cell idempotent: re-running it no longer
# appends the same entry to sys.path again.
module_path = os.path.abspath("..")
if module_path not in sys.path:
    sys.path.append(module_path)
import torch
from torchvision import datasets

from anomaly_scores.energy import energy_anomaly_score
from anomaly_scores.max_logit import max_logit_anomaly_score
from anomaly_scores.softmax import max_softmax_anomaly_score
from anomaly_scores.vim_scores import VIM
from energy_ood.CIFAR.models.wrn import WideResNet
from energy_ood.utils.svhn_loader import SVHN
from energy_ood.utils.tinyimages_80mn_loader import TinyImages
from util import TEST_TRANSFORM, TINY_TRANSFORM
from util.display_results import compare_all_results
from util.get_ood_score import get_ood_score_for_multiple_datasets
from vim_training.test import test

## The data
Let's start by replicating the results from the paper. CIFAR-10 (test split) is the in-distribution data; we compare it first against SVHN and then against CIFAR-100.

In [2]:
# DataLoader options shared by every data set in this notebook.
loader_kwargs = dict(batch_size=200, num_workers=2, pin_memory=True)

# In-distribution data: CIFAR-10 test split, iterated in a fixed order.
id_data = datasets.CIFAR10("../data/cifar10", train=False, transform=TEST_TRANSFORM)
id_loader = torch.utils.data.DataLoader(id_data, shuffle=False, **loader_kwargs)

# First comparison data set: SVHN test split (must already be on disk,
# since download=False), iterated in shuffled order.
ood_data = SVHN(
    root="../data/svhn/", split="test", transform=TEST_TRANSFORM, download=False
)
ood_loader = torch.utils.data.DataLoader(ood_data, shuffle=True, **loader_kwargs)
# One fifth of the CIFAR-10 test set size; not used anywhere else in this cell.
ood_num_examples = len(id_loader.dataset) // 5

# Second comparison data set: CIFAR-100 test split, iterated in shuffled order.
data = datasets.CIFAR100("../data/cifar-100", train=False, transform=TEST_TRANSFORM)
loader = torch.utils.data.DataLoader(data, shuffle=True, **loader_kwargs)

# (name, loader) pairs; the in-distribution loader comes first because later
# cells access it as loaders[0][1].
loaders = [
    ("CIFAR10", id_loader),
    ("SVHN", ood_loader),
    ("CIFAR100", loader),
]

## Models

We are using the Wide ResNet as in the paper.

In [3]:
def load_wrn(checkpoint_path):
    """Build the CIFAR-10 Wide ResNet and load its weights from a checkpoint.

    Every model compared in this notebook shares the same architecture
    (depth=40, widen_factor=2, dropRate=0.3, 10 classes); only the checkpoint
    differs.

    Args:
        checkpoint_path: Path to a state-dict file readable by ``torch.load``.

    Returns:
        The ``WideResNet`` instance, switched to eval mode and moved to the GPU.
    """
    net = WideResNet(depth=40, num_classes=10, widen_factor=2, dropRate=0.3)
    net.load_state_dict(torch.load(checkpoint_path))
    net.eval()
    net.cuda()
    return net


# (display name, checkpoint path) of every model to compare, in evaluation order.
WRN_CHECKPOINTS = [
    (
        "WRN",
        "../energy_ood/CIFAR/snapshots/pretrained/cifar10_wrn_pretrained_epoch_99.pt",
    ),
    (
        "WRN Hendr.",
        "../../interesting_papers/outlier-exposure/CIFAR/snapshots/baseline/cifar10_wrn_baseline_epoch_99.pt",
    ),
    (
        "WRN Hendr. Calib",
        "../../interesting_papers/outlier-exposure/CIFAR/snapshots/baseline/cifar10_calib_wrn_baseline_epoch_99.pt",
    ),
    ("WRN ours", "../snapshots/default/CIFAR10_WRN_epoch_99.pt"),
    (
        "Energy_ft",
        "../energy_ood/CIFAR/snapshots/energy_ft/cifar10_wrn_s1_energy_ft_epoch_9.pt",
    ),
    ("Energy_ft_ours", "../snapshots/energy/CIFAR10_WRN_epoch_9.pt"),
    # Disabled: the ViM fine-tuned model. It cannot go through load_wrn because
    # it was built with WideResVIMNet(depth=40, num_classes=10, loader=id_loader,
    # widen_factor=2, dropRate=0.3) instead of WideResNet.
    # ("ViM_ft", "../snapshots/vim/CIFAR10_WRN_epoch_9.pt"),
    ("vanilla_ft", "../snapshots/vanilla_ft/CIFAR10_WRN_epoch_9.pt"),
]

# List of (display name, model) pairs consumed by the scoring cell.
models = []
for model_name, checkpoint_path in WRN_CHECKPOINTS:
    model = load_wrn(checkpoint_path)
    models.append((model_name, model))

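## Anomaly Scores
Let's compare the scores: for every model we evaluate MaxLogit, MaxSoftmax, Energy and ViM on each of the data sets loaded above.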

In [4]:
# Results keyed by model name, then by "test_acc" or by score name.
all_anomaly_results = {}
for model_name, model in models:
    print(model_name)
    all_anomaly_results[model_name] = {}
    # ViM is built per model from the in-distribution loader.
    vim = VIM(id_loader, model)

    # (score name, scoring callable, value passed as `is_using`).
    # Only ViM is called with "last_penultimate"; the other scores use "last".
    score_specs = [
        ("MaxLogit", max_logit_anomaly_score, "last"),
        ("MaxSoftmax", max_softmax_anomaly_score, "last"),
        ("Energy", energy_anomaly_score, "last"),
        ("VIM", vim.compute_anomaly_score, "last_penultimate"),
    ]

    # The second value returned by test() on the CIFAR10 loader is stored
    # under "test_acc".
    _, test_accuracy = test(model, loaders[0][1])
    model_results = all_anomaly_results[model_name]
    model_results["test_acc"] = test_accuracy

    for score_name, score_fn, outputs_used in score_specs:
        print("  ", score_name)
        model_results[score_name] = get_ood_score_for_multiple_datasets(
            loaders, model, score_fn, is_using=outputs_used
        )

WRN
   MaxLogit
   MaxSoftmax
   Energy
   VIM
WRN Hendr.
   MaxLogit
   MaxSoftmax
   Energy
   VIM
WRN Hendr. Calib
   MaxLogit
   MaxSoftmax
   Energy
   VIM
WRN ours
   MaxLogit
   MaxSoftmax
   Energy
   VIM
Energy_ft
   MaxLogit
   MaxSoftmax
   Energy
   VIM
Energy_ft_ours
   MaxLogit
   MaxSoftmax
   Energy
   VIM
vanilla_ft
   MaxLogit


In [None]:
# Show the collected results. The saved output below has one table per model,
# with one row per anomaly score and one column per comparison data set plus
# an average column.
compare_all_results(all_anomaly_results, loaders)

              WRN (5.15%) |     SVHN     |   CIFAR100   |     AVG     
                 MaxLogit |    91.17%    |    87.25%    |    89.21%   
               MaxSoftmax |    91.98%    |   *88.09%    |    90.03%   
                   Energy |    91.27%    |    87.01%    |    89.14%   
                      VIM |   *95.02%    |    85.61%    |   *90.32%   

       WRN Hendr. (5.15%) |     SVHN     |   CIFAR100   |     AVG     
                 MaxLogit |    90.98%    |    87.19%    |    89.09%   
               MaxSoftmax |    92.22%    |   *88.19%    |    90.21%   
                   Energy |    91.03%    |    87.03%    |    89.03%   
                      VIM |   *94.84%    |    85.82%    |   *90.33%   

 WRN Hendr. Calib (5.58%) |     SVHN     |   CIFAR100   |     AVG     
                 MaxLogit |    86.75%    |    85.37%    |    86.06%   
               MaxSoftmax |    87.89%    |   *87.21%    |    87.55%   
                   Energy |    86.32%    |    85.07%    |    85.69%   
    