In [68]:
import pandas
import torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import torchmetrics

In [69]:
# Load ResNet-152 with pretrained ImageNet-1k weights and move it to the GPU.
# `weights=True` is the legacy boolean form, which torchvision (>= 0.13) only
# accepts with a deprecation warning. Naming the weights enum makes the exact
# checkpoint explicit; IMAGENET1K_V1 is the checkpoint the legacy flag
# resolves to, so the loaded parameters are unchanged.
model = torchvision.models.resnet152(
    weights=torchvision.models.ResNet152_Weights.IMAGENET1K_V1
).cuda()



In [70]:
# Evaluation-time preprocessing: resize, centre-crop to 224, convert to a
# tensor, then standardise each channel with the mean/std given to Normalize.
normalize = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# ImageNet 'val' split rooted at `data`, with the pipeline above as its transform.
imagenet = torchvision.datasets.ImageNet(
    root='data',
    split='val',
    transform=normalize,
)

In [71]:
# Validation / test split: 20% of the images go to `val_loader`, the remaining
# 80% to `test_loader`.
# - The test size is computed as the remainder so the two lengths always sum
#   to len(imagenet); random_split raises if they do not.
# - A seeded generator keeps the split identical across kernel restarts, so
#   the reported metrics are reproducible.
SPLIT_SEED = 0
n_val = int(len(imagenet) * 0.2)
n_test = len(imagenet) - n_val
val_subset, test_subset = torch.utils.data.random_split(
    imagenet,
    [n_val, n_test],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
val_loader = torch.utils.data.DataLoader(val_subset, batch_size=8, shuffle=False, num_workers=1)
test_loader = torch.utils.data.DataLoader(test_subset, batch_size=8, shuffle=False, num_workers=1)


In [72]:
# Run the uncalibrated model over the test loader and gather its outputs and
# the matching labels on the CPU.
model.eval()
logit_batches = []
label_batches = []
with torch.no_grad():
    for x, y in test_loader:
        y_hat = model(x.cuda())
        logit_batches.append(y_hat.cpu())
        label_batches.append(y.cpu())

# Concatenate once at the end: calling torch.cat inside the loop re-copies
# everything gathered so far on every batch. Building `targets` from the label
# batches alone (instead of growing an empty float tensor) also keeps it in the
# dtype the loader yields rather than promoting the class indices to float.
predictions = torch.cat(logit_batches)
targets = torch.cat(label_batches)

In [73]:
# Multiclass accuracy of the uncalibrated model outputs against the test labels.
accScore = torchmetrics.functional.accuracy(
    preds=predictions,
    target=targets,
    task='multiclass',
    num_classes=1000,
)
print(f'Accuracy: {accScore.detach()}')

Accuracy: 0.7852500081062317


In [91]:
# Sanity check on the gathered outputs: (number of test images, number of classes).
predictions.size()

torch.Size([40000, 1000])

In [101]:
from temp_scale import _ECELoss

# Calibration metrics of the uncalibrated outputs on the test split.
ece_criterion = _ECELoss()
test_ece = ece_criterion(predictions, targets)
print('ECE in test set: %.3f' % test_ece)

# cross_entropy needs integer class indices, hence the cast of `targets`.
test_nll = F.cross_entropy(predictions, targets.long())
print('NLL in test set: %.3f' % test_nll)

ECE in test set: 0.049
NLL in test set: 0.873


In [77]:
from temp_scale import ModelWithTemperature

# Wrap the pretrained network in the ModelWithTemperature wrapper from temp_scale.
# NOTE(review): presumably the wrapper rescales the wrapped model's logits by a
# single learnable temperature — confirm against temp_scale.
temp_calib_model = ModelWithTemperature(model)

In [78]:
# Fit the temperature on the validation loader and keep whatever set_temperature
# returns as the calibrated model. The before/after NLL and ECE lines in this
# cell's output appear to be printed by set_temperature itself.
temp_calib_model = temp_calib_model.set_temperature(val_loader)

Before temperature - NLL: 0.890, ECE: 0.054
Optimal temperature: 1.377
After temperature - NLL: 0.879, ECE: 0.035


In [79]:
# Run the temperature-calibrated model over the test loader and gather its
# outputs and the matching labels on the CPU.
temp_calib_model.eval()
ts_logit_batches = []
ts_label_batches = []
with torch.no_grad():
    for x, y in test_loader:
        y_hat = temp_calib_model(x.cuda())
        ts_logit_batches.append(y_hat.cpu())
        ts_label_batches.append(y.cpu())

# Concatenate once at the end: calling torch.cat inside the loop re-copies
# everything gathered so far on every batch. Building `ts_targets` from the
# label batches alone (instead of growing an empty float tensor) also keeps it
# in the dtype the loader yields rather than promoting the indices to float.
ts_predictions = torch.cat(ts_logit_batches)
ts_targets = torch.cat(ts_label_batches)

In [80]:
# Multiclass accuracy of the temperature-calibrated outputs against the test labels.
accScore = torchmetrics.functional.accuracy(
    preds=ts_predictions,
    target=ts_targets,
    task='multiclass',
    num_classes=1000,
)
print(f'Accuracy: {accScore.detach()}')

Accuracy: 0.7852500081062317


In [102]:
from temp_scale import _ECELoss

# Calibration metrics of the temperature-calibrated outputs on the test split.
ts_ece_criterion = _ECELoss()
ts_test_ece = ts_ece_criterion(ts_predictions, ts_targets)
print('ECE in test set: %.3f' % ts_test_ece)

# cross_entropy needs integer class indices, hence the cast of `ts_targets`.
ts_test_nll = F.cross_entropy(ts_predictions, ts_targets.long())
print('NLL in test set: %.3f' % ts_test_nll)

ECE in test set: 0.041
NLL in test set: 0.861
