# Energy-based out-of-distribution detection

This notebook is based on the paper [Energy-based Out-of-distribution Detection](http://arxiv.org/abs/2010.03759). We explore the folder structure, the models, and the available scripts, and extract what we need for our own ideas.

In [1]:
import torch
from torchvision import datasets, transforms

from anomaly_scores.max_logit import max_logit_anomaly_score
from anomaly_scores.energy import energy_anomaly_score
from anomaly_scores.softmax import max_softmax_anomaly_score
from energy_ood.CIFAR.models.wrn import WideResNet
from energy_ood.utils.display_results import get_measures, show_performance
from energy_ood.utils.svhn_loader import SVHN
from util.get_ood_score import get_ood_scores

In [2]:
# WideResNet (depth 40, widen factor 2) for the 10 CIFAR-10 classes, restored from
# the pretrained snapshot shipped with the energy_ood code.
CHECKPOINT_PATH = "energy_ood/CIFAR/snapshots/pretrained/cifar10_wrn_pretrained_epoch_99.pt"

model = WideResNet(depth=40, num_classes=10, widen_factor=2, dropRate=0.3)
# Deserialize the weights onto the CPU first, independent of the device they were
# saved from; the model is moved to the GPU below.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu"))
# Inference mode (e.g. dropout off), then move to the default CUDA device.
# From here on a GPU is required.
_ = model.eval().cuda()

Let's start by replicating the paper's results, beginning with SVHN as the out-of-distribution dataset (the CIFAR-10 test set is the in-distribution data).

In [3]:
# Per-channel mean and standard deviation of CIFAR-10 images, given on the 0-255
# scale and rescaled to [0, 1] to match ToTensor's output range.
mean = [x / 255 for x in [125.3, 123.0, 113.9]]
std = [x / 255 for x in [63.0, 62.1, 66.7]]
# One transform serves both datasets. OOD images must be normalized with the
# *in-distribution* statistics, because that is what the classifier was trained on.
test_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean, std)]
)

BATCH_SIZE = 200
# Seeds the OOD loader's shuffle so that "Restart & Run All" reproduces the scores.
OOD_SHUFFLE_SEED = 0

id_data = datasets.CIFAR10("data/cifar10", train=False, transform=test_transform)
id_loader = torch.utils.data.DataLoader(
    id_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True
)

# No resize is applied to SVHN. Its images are presumably already 32x32 like
# CIFAR-10 (confirm if switching to another OOD set).
ood_data = SVHN(
    root="data/svhn/",
    split="test",
    transform=test_transform,
    download=False,
)
# Shuffling is kept because the scoring helper may only consume part of the OOD
# loader (TODO confirm in util.get_ood_score). The seeded generator makes the
# sequence of shuffle orders reproducible across kernel restarts.
ood_loader = torch.utils.data.DataLoader(
    ood_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    generator=torch.Generator().manual_seed(OOD_SHUFFLE_SEED),
)

In [4]:
# NOTE(review): `runs` is not referenced by any of the scoring cells below. It
# presumably sets how many times the OOD evaluation should be repeated and averaged,
# since the OOD loader is shuffled. Confirm it is used later in the notebook;
# otherwise remove it.
runs: int = 10

### Max-logit anomaly score

In [5]:
# Score the in-distribution (CIFAR-10) and OOD (SVHN) test sets with the max-logit
# anomaly score. The generator is consumed in order, so CIFAR-10 is scored first.
in_score, out_score = (
    get_ood_scores(loader, max_logit_anomaly_score, model, use_penultimate=True)
    for loader in (id_loader, ood_loader)
)
# OOD scores are passed first. Presumably they are the outlier/positive class for
# show_performance (TODO confirm in energy_ood.utils.display_results).
show_performance(out_score, in_score, method_name="Max Logit")

			Max Logit
FPR95:			43.25
AUROC:			91.28
AUPR:			96.14


In [6]:
# Same evaluation as above, but with the maximum-softmax-probability anomaly score.
# CIFAR-10 is scored first, then SVHN.
in_score, out_score = (
    get_ood_scores(loader, max_softmax_anomaly_score, model, use_penultimate=True)
    for loader in (id_loader, ood_loader)
)
# OOD scores are passed first. Presumably they are the outlier/positive class for
# show_performance (TODO confirm in energy_ood.utils.display_results).
show_performance(out_score, in_score, method_name="Softmax")

			Softmax
FPR95:			29.50
AUROC:			92.17
AUPR:			96.03


In [7]:
# Same evaluation with the energy anomaly score from the paper.
# CIFAR-10 is scored first, then SVHN.
in_score, out_score = (
    get_ood_scores(loader, energy_anomaly_score, model, use_penultimate=True)
    for loader in (id_loader, ood_loader)
)
# OOD scores are passed first. Presumably they are the outlier/positive class for
# show_performance (TODO confirm in energy_ood.utils.display_results).
# NOTE(review): in the saved outputs, Softmax reaches a lower FPR95 than Energy.
# Double-check the sign conventions of the score functions before comparing
# against the paper.
show_performance(out_score, in_score, method_name="Energy")

			Energy
FPR95:			45.35
AUROC:			91.45
AUPR:			96.22
