In [5]:
import pandas
import torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
# Make the local checkout of pytorch_resnet_cifar10 importable (presumably
# where the `resnet` module imported right below lives -- confirm locally).
# list.insert() returns None, so its result is deliberately not bound to a
# name; the membership guard keeps re-running this cell from stacking
# duplicate sys.path entries.
RESNET_REPO_DIR = "C:/PythonEnviroments/pytorch_resnet_cifar10"
if RESNET_REPO_DIR not in sys.path:
    sys.path.insert(1, RESNET_REPO_DIR)
from resnet import resnet110
from torchvision import transforms
import torchmetrics
import dill

In [9]:
# Build the network architecture corresponding to the checkpoint.
model = resnet110()

# SECURITY NOTE: torch.load unpickles the file it reads, so only load
# checkpoints that come from a trusted source.
# map_location places every loaded tensor on the first GPU.
check_point = torch.load('models/resnet110-1d1ed7c2.pth', map_location='cuda:0')

# The checkpoint was saved from a DataParallel-wrapped model, so the network
# has to be wrapped the same way before its state_dict can be loaded.
model = torch.nn.DataParallel(model)
model.load_state_dict(check_point['state_dict'])

# Save the inner network (`.module`), not the DataParallel wrapper; otherwise
# the reloaded object would still be wrapped in DataParallel, which causes
# trouble for code that expects the plain model.
torch.save(model.module, 'resnet110.pth', pickle_module=dill)

# Reload the converted pretrained model with the same pickle module that
# wrote it, so the save/load round trip is symmetric.
# NOTE(review): recent PyTorch releases restrict what a default torch.load
# may unpickle; passing the pickle module explicitly is expected to keep
# whole-module loading working -- confirm against the installed version.
model = torch.load('resnet110.pth', map_location='cuda:0', pickle_module=dill)


In [7]:

import torchvision.transforms as transforms

# Input preprocessing: convert images to tensors, then normalise per channel.
# NOTE(review): these values look like the usual ImageNet channel statistics
# rather than CIFAR-10 ones; presumably they match how the checkpoint was
# trained -- confirm against the training script before changing them.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

train_set = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True,
    transform=transform,
)
test_set = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True,
    transform=transform
)

# Split the official test data in half: one half stays the test set, the
# other becomes the validation set used further down for temperature fitting.
# The split is seeded so that Restart & Run All reproduces the same
# membership (and therefore the same metrics). The second size is derived by
# subtraction so the two sizes always sum to the dataset length, which
# random_split requires even when the length is odd.
SPLIT_SEED = 0
n_test = len(test_set) // 2
n_val = len(test_set) - n_test
test_set, val_set = torch.utils.data.random_split(
    test_set,
    [n_test, n_val],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=128, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=128, shuffle=False, num_workers=2)
val_loader = torch.utils.data.DataLoader(
    val_set, batch_size=128, shuffle=False, num_workers=2)


Files already downloaded and verified
Files already downloaded and verified


In [10]:
# Run the uncalibrated model over the test loader and gather its raw outputs
# (`predictions`) together with the ground-truth labels (`targets`) on the CPU.
model.eval()
logit_batches = []
label_batches = []
with torch.no_grad():
    for x, y in test_loader:
        y_hat = model(x.cuda())
        logit_batches.append(y_hat.cpu())
        label_batches.append(y.cpu())

# Concatenate once at the end: calling torch.cat inside the loop re-copies the
# whole accumulator on every batch. Building `targets` from the label batches
# themselves (instead of growing a float-typed empty tensor) keeps the class
# indices in their original integer dtype.
predictions = torch.cat(logit_batches) if logit_batches else torch.tensor([])
targets = (torch.cat(label_batches) if label_batches
           else torch.tensor([], dtype=torch.long))

In [12]:
# Multiclass accuracy (10 classes) of the uncalibrated model on the test split.
accScore = torchmetrics.functional.accuracy(
    predictions,
    targets,
    task='multiclass',
    num_classes=10,
)
print('Accuracy: {}'.format(accScore.detach()))

Accuracy: 0.9363999962806702


In [13]:
from temp_scale import _ECELoss

# Calibration metrics of the raw model outputs on the test split: ECE via the
# project's _ECELoss criterion and NLL via cross-entropy.
# cross_entropy is given the labels cast to a long tensor.
ece_criterion = _ECELoss()
test_labels = targets.type(torch.LongTensor)
print('ECE in test set: %.3f' % ece_criterion(predictions, targets))
print('NLL in test set: %.3f' % F.cross_entropy(predictions, test_labels))

ECE in test set: 0.044
NLL in test set: 0.299


In [14]:
from temp_scale import ModelWithTemperature

# Wrap the pretrained network in the project's temperature-scaling wrapper.
# NOTE(review): presumably the wrapper rescales the model's logits by a single
# learnable temperature -- see temp_scale.py for the actual behaviour.
temp_calib_model = ModelWithTemperature(model)

In [15]:
# Fit the temperature using the validation loader (not the test loader); the
# cell output reports NLL/ECE before and after, plus the chosen temperature.
# The call's return value is rebound to the same name.
# NOTE(review): presumably set_temperature returns the wrapper itself --
# confirm in temp_scale.py.
temp_calib_model = temp_calib_model.set_temperature(val_loader)

Before temperature - NLL: 0.312, ECE: 0.043
Optimal temperature: 1.733
After temperature - NLL: 0.214, ECE: 0.025


In [16]:
# Run the temperature-scaled model over the test loader and gather its outputs
# (`ts_predictions`) together with the ground-truth labels (`ts_targets`) on
# the CPU.
temp_calib_model.eval()
ts_logit_batches = []
ts_label_batches = []
with torch.no_grad():
    for x, y in test_loader:
        y_hat = temp_calib_model(x.cuda())
        ts_logit_batches.append(y_hat.cpu())
        ts_label_batches.append(y.cpu())

# Concatenate once at the end: calling torch.cat inside the loop re-copies the
# whole accumulator on every batch. Building `ts_targets` from the label
# batches themselves (instead of growing a float-typed empty tensor) keeps the
# class indices in their original integer dtype.
ts_predictions = (torch.cat(ts_logit_batches) if ts_logit_batches
                  else torch.tensor([]))
ts_targets = (torch.cat(ts_label_batches) if ts_label_batches
              else torch.tensor([], dtype=torch.long))

In [18]:
# Multiclass accuracy (10 classes) of the temperature-scaled model on the
# test split.
accScore = torchmetrics.functional.accuracy(
    ts_predictions,
    ts_targets,
    task='multiclass',
    num_classes=10,
)
print('Accuracy: {}'.format(accScore.detach()))

Accuracy: 0.9363999962806702


In [19]:
from temp_scale import _ECELoss

# Calibration metrics of the temperature-scaled outputs on the test split:
# ECE via the project's _ECELoss criterion and NLL via cross-entropy.
# cross_entropy is given the labels cast to a long tensor.
ts_ece_criterion = _ECELoss()
ts_test_labels = ts_targets.type(torch.LongTensor)
print('ECE in test set: %.3f' % ts_ece_criterion(ts_predictions, ts_targets))
print('NLL in test set: %.3f' % F.cross_entropy(ts_predictions, ts_test_labels))

ECE in test set: 0.025
NLL in test set: 0.206
