# Confidence calibration evaluation

In [1]:
import torch
from torchvision import datasets, transforms
from anomaly_scores.vim_scores import VIM
from energy_ood.CIFAR.models.wrn import WideResNet

from energy_ood.utils.svhn_loader import SVHN
from util.confidence import compute_rms_calibration_error
from util.get_ood_score import to_np
import numpy as np

## Model

We are using the Wide ResNet as in the paper.

In [2]:
# Build the Wide ResNet (depth 40, widen factor 2, 10 output classes) and
# restore its weights from the pretrained snapshot on disk.
model = WideResNet(depth=40, num_classes=10, widen_factor=2, dropRate=0.3)
model.load_state_dict(
    torch.load(
        "energy_ood/CIFAR/snapshots/pretrained/cifar10_wrn_pretrained_epoch_99.pt",
        # Deserialize onto the CPU first so a checkpoint written from a GPU can
        # still be restored on a machine without one; the device move is
        # done explicitly afterwards.
        map_location="cpu",
    )
)
# Inference only: put the module in evaluation mode.
model.eval()
# Move to the GPU only when one is present, matching the
# torch.cuda.is_available() guard the evaluation loops apply to their inputs.
if torch.cuda.is_available():
    model.cuda()

## The data

In [3]:
# Release cached GPU memory before the data pipeline is set up.
torch.cuda.empty_cache()

# Per-channel mean and standard deviation of CIFAR-10 images, rescaled from
# 0-255 pixel values by dividing by 255.
mean = [channel / 255 for channel in (125.3, 123.0, 113.9)]
std = [channel / 255 for channel in (63.0, 62.1, 66.7)]

# Test-time preprocessing: tensor conversion followed by normalization.
normalize = transforms.Normalize(mean, std)
test_transform = transforms.Compose([transforms.ToTensor(), normalize])

# CIFAR-10 test split, served in fixed order in batches of 2500.
test_dataset = datasets.CIFAR10(
    root="data/cifar10", train=False, transform=test_transform
)
data_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=2500,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

# NOTE(review): VIM is constructed from the CIFAR-10 *test* loader, the same
# data it is evaluated on later - presumably it fits its statistics on this
# loader; confirm whether the training split was intended here.
vim = VIM(data_loader, model)

## On vanilla model

In [4]:
# Per-batch RMS calibration error on the CIFAR-10 test loader: once for the
# plain softmax, once for the softmax extended with the ViM virtual logit.
error_vanilla = []
error_vim = []
# Evaluation only: without no_grad() every forward pass would also record an
# autograd graph for the whole batch, which costs memory and is never used.
with torch.no_grad():
    # `images` instead of `input` so the builtin is not shadowed.
    for images, labels in data_loader:
        if torch.cuda.is_available():
            images = images.cuda()
        logits, penultimate = model(images)

        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        error_vanilla.append(
            compute_rms_calibration_error(to_np(labels), to_np(probabilities))
        )

        # The anomaly score is handled as a NumPy array here: it gets a
        # trailing axis and is appended to the logits as one extra column.
        virtual_logit = vim.compute_anomaly_score(logits, penultimate)
        virtual_logits = torch.hstack(
            (logits.cpu(), torch.from_numpy(np.expand_dims(virtual_logit, axis=1)))
        )
        virtual_probabilities = torch.nn.functional.softmax(virtual_logits, dim=-1)

        # Score only the real classes. The slice width comes from the logits
        # instead of a literal 10, so it follows the model's output size.
        num_classes = logits.shape[1]
        error_vim.append(
            compute_rms_calibration_error(
                to_np(labels), to_np(virtual_probabilities[:, :num_classes])
            )
        )

In [5]:
# Mean and standard deviation of the per-batch calibration error (plain softmax).
vanilla_stats = (np.mean(error_vanilla), np.std(error_vanilla))
print(*vanilla_stats)

0.04857253001692445 0.006443243636239018


In [6]:
# Mean and standard deviation of the per-batch calibration error (with virtual logit).
vim_stats = (np.mean(error_vim), np.std(error_vim))
print(*vim_stats)

0.05959132114844886 0.0017581347136628384


OK, so ViM does not improve calibration here; it may even be slightly worse. One possible reason is that the predicted probabilities are not spread over the entire probability range, which leaves some bins in the low-confidence region with only a few samples.

## With OOD dataset

In [7]:
# SVHN test split used as the out-of-distribution set. It goes through the
# same ToTensor + Normalize(mean, std) pipeline as the CIFAR-10 test data.
ood_data = SVHN(
    root="data/svhn/",
    split="test",
    transform=transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(mean, std)]
    ),
    download=False,
)
# Batches are shuffled and the calibration error is computed per batch, so the
# shuffle order influences the reported mean/std. A fixed-seed generator makes
# that order, and therefore the numbers, reproducible from a fresh run.
ood_loader = torch.utils.data.DataLoader(
    ood_data,
    batch_size=2000,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    generator=torch.Generator().manual_seed(0),
)

In [8]:
# Per-batch RMS calibration error on the SVHN (out-of-distribution) loader:
# once for the plain softmax, once for the softmax extended with the ViM
# virtual logit.
# NOTE(review): `labels` are whatever the SVHN loader yields, scored against a
# classifier restored from a CIFAR-10 checkpoint - confirm this is the
# intended reference for "calibration" on OOD inputs.
error_vanilla = []
error_vim = []
# Evaluation only: without no_grad() every forward pass would also record an
# autograd graph for the whole batch, which costs memory and is never used.
with torch.no_grad():
    # `images` instead of `input` so the builtin is not shadowed.
    for images, labels in ood_loader:
        if torch.cuda.is_available():
            images = images.cuda()
        logits, penultimate = model(images)

        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        error_vanilla.append(
            compute_rms_calibration_error(to_np(labels), to_np(probabilities))
        )

        # The anomaly score is handled as a NumPy array here: it gets a
        # trailing axis and is appended to the logits as one extra column.
        virtual_logit = vim.compute_anomaly_score(logits, penultimate)
        virtual_logits = torch.hstack(
            (logits.cpu(), torch.from_numpy(np.expand_dims(virtual_logit, axis=1)))
        )
        virtual_probabilities = torch.nn.functional.softmax(virtual_logits, dim=-1)

        # Score only the real classes. The slice width comes from the logits
        # instead of a literal 10, so it follows the model's output size.
        num_classes = logits.shape[1]
        error_vim.append(
            compute_rms_calibration_error(
                to_np(labels), to_np(virtual_probabilities[:, :num_classes])
            )
        )

In [9]:
# Mean and standard deviation of the per-batch calibration error:
# plain softmax first, then the variant with the virtual logit.
for batch_errors in (error_vanilla, error_vim):
    print(np.mean(batch_errors), np.std(batch_errors))

0.6884658125758784 0.010476772860678628
0.5916646976370383 0.008349233911948351


That is what I expected: an improvement of around 10 percentage points.