In [None]:
import warnings
warnings.simplefilter("ignore", UserWarning)
import pandas as pd
import torch
import torch.distributions as dists
import numpy as np
import helper.wideresnet as wrn
import helper.dataloaders as dl
from helper import util
from helper.calibration_gp_utils import predict, gp_calibration_eval
from netcal.metrics import ECE

from laplace import Laplace

: 

In [None]:
# Seed every RNG used in this notebook so Restart & Run All is reproducible.
SEED = 7777
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices (CPU and CUDA)
# Deterministic cuDNN kernels only help if autotuning is off: with
# benchmark=True cuDNN may select a different (possibly non-deterministic)
# algorithm on each run, silently undoing `deterministic=True`.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

: 

In [None]:
# Report which accelerators this PyTorch build can see.
# `torch.has_mps` is deprecated; `torch.backends.mps.is_built()` is its
# supported replacement.
print(torch.backends.mps.is_built())
print(torch.cuda.is_available())
print(torch.cuda.device_count())
# current_device() / get_device_name() raise when no CUDA device exists,
# so only query them when CUDA is actually available.
if torch.cuda.is_available():
    print(torch.cuda.current_device())
    print(torch.cuda.get_device_name(0))

: 

In [None]:
# CIFAR-10 train/test loaders built by the project helper `helper.dataloaders`.
# Evaluated left to right: the train loader is constructed first, then the test loader.
train_loader, test_loader = dl.CIFAR10(train=True), dl.CIFAR10(train=False)

: 

In [None]:
# The model is a standard WideResNet 16-4,
# taken as is from https://github.com/hendrycks/outlier-exposure
# NOTE(review): assumes download_pretrained_model() writes to this path — confirm in helper.util.
PRETRAINED_PATH = './temp/CIFAR10_plain.pt'

model = wrn.WideResNet(16, 4, num_classes=10).cuda().eval()

util.download_pretrained_model()
# map_location='cpu' avoids failures when the checkpoint was saved on a CUDA
# device index not present here; load_state_dict copies the tensors onto the
# model's (CUDA) parameters anyway.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# this checkpoint is downloaded, so consider weights_only=True if the installed
# torch version supports it.
model.load_state_dict(torch.load(PRETRAINED_PATH, map_location='cpu'));

: 

In [None]:

# MAP baseline: score the point-estimate network (no Laplace approximation)
# on the test set with accuracy, 15-bin ECE, and mean negative log-likelihood.
# NOTE(review): labels and predictions come from two separate passes over
# test_loader; this assumes the loader iterates in a fixed order (no
# shuffling) — confirm in helper.dataloaders.
targets = torch.cat([batch_y for _, batch_y in test_loader]).cpu()
probs_map = predict(test_loader, model, laplace=False)

acc_map = probs_map.argmax(dim=-1).eq(targets).float().mean()
ece_map = ECE(bins=15).measure(probs_map.numpy(), targets.numpy())
nll_map = dists.Categorical(probs=probs_map).log_prob(targets).mean().neg()

print(f'[MAP] Acc.: {acc_map:.1%}; ECE: {ece_map:.1%}; NLL: {nll_map:.3}')

: 

In [None]:
# Run the project's GP-based calibration evaluation on the pretrained model.
# NOTE(review): what gp_calibration_eval fits and returns is defined in
# helper.calibration_gp_utils (not visible here) — confirm there.
metrics_gp = gp_calibration_eval(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
)

: 

In [None]:
# Display the result of gp_calibration_eval as this cell's output
# (its type is not visible here — presumably a mapping/table of metrics; verify).
metrics_gp

: 

: 