# Temperature Scaling

### Mount drive

In [None]:
# Mount Google Drive inside the Colab runtime; every data, module and model
# path used further down is located under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

### Import required libraries

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from datetime import date
from itertools import product
import os
import torchvision.models as tmodels
from functools import partial
import collections

The code loaded below is reused from https://github.com/gpleiss/temperature_scaling

"On Calibration of Modern Neural Networks"
Chuan Guo, Geoff Pleiss, Yu Sun, Kilian Q. Weinberger
https://arxiv.org/abs/1706.04599

In [None]:
# Execute the temperature-scaling module in this notebook's namespace.
# NOTE(review): ModelWithTemperature, used further down, is presumably
# defined by this script - confirm against the file on Drive.
%run '/content/drive/MyDrive/KASHIKO/MODULES/temperature_scaling.py'

### Prepare data

In [None]:
# Name of the image folder (below the DATASET directory) to load.
dataset_name = "TEST_0_FINAL"
# Value matched against the "Dataset" column of the normalisation-parameter
# CSV to pick the per-channel mean/std row.
norm_param_dataset_ref = "AVG"

In [None]:
# Per-dataset normalisation statistics (one row per "Dataset" value).
norm_param_df = pd.read_csv('/content/drive/MyDrive/KASHIKO/DATASET/TRG_DATASET_NORM_PARAM.csv')

# Select the reference row once instead of re-evaluating the same boolean
# filter for each of the six statistics.
_norm_param_row = norm_param_df.loc[
    norm_param_df["Dataset"] == str(norm_param_dataset_ref)
]

# Fail with an explicit message when the reference is missing or ambiguous;
# otherwise `.item()` would raise a generic size-mismatch ValueError.
if len(_norm_param_row) != 1:
    raise ValueError(
        f"Expected exactly one row with Dataset == {norm_param_dataset_ref!r} "
        f"in the normalisation-parameter CSV, found {len(_norm_param_row)}"
    )

# Per-channel means and standard deviations as plain Python scalars.
meanR, meanG, meanB = (
    _norm_param_row[column].item() for column in ("meanR", "meanG", "meanB")
)
stdR, stdG, stdB = (
    _norm_param_row[column].item() for column in ("stdR", "stdG", "stdB")
)

In [None]:
# Image folder dataset: each image is converted to a tensor and normalised
# per channel with the statistics read from the normalisation-parameter CSV.
dataset = datasets.ImageFolder(
    '/content/drive/MyDrive/KASHIKO/DATASET/' + dataset_name,
    transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((meanR, meanG, meanB), (stdR, stdG, stdB))
    ])
)


def _random_subset(source, size, seed=0):
    """Return a reproducible random subset of `size` samples from `source`.

    A dedicated, seeded generator is used so the same samples are drawn on
    every run; without it the subset (and therefore any temperature fitted
    on it) changes each time the notebook is executed.

    Raises:
        ValueError: if `size` exceeds the number of samples in `source`,
            instead of handing a negative split length to `random_split`.
    """
    if size > len(source):
        raise ValueError(
            f"Cannot draw {size} samples from a dataset of {len(source)} samples"
        )
    generator = torch.Generator().manual_seed(seed)
    splits = torch.utils.data.random_split(
        source, [len(source) - size, size], generator=generator
    )
    return splits[1]


# Random subsets of 10, 100 and 2000 samples respectively.
short_dataset = _random_subset(dataset, 10)
long_dataset = _random_subset(dataset, 100)
extra_long_dataset = _random_subset(dataset, 2000)

In [None]:
# Serve the 2000-sample subset one image per batch, in a fixed order,
# using two worker processes.
_loader_options = dict(
    batch_size=1,
    shuffle=False,
    num_workers=2,
    drop_last=True,
)
loader = torch.utils.data.DataLoader(extra_long_dataset, **_loader_options)

### Load model

In [None]:
class Net(nn.Module):
    """Small CNN with two conv/batch-norm/pool stages and three linear layers.

    The first linear layer expects 24*53*53 features, which is what the two
    unpadded 5x5 convolutions and 2x2 poolings produce for a 3x224x224 input.
    The final layer emits two raw scores (logits) per sample.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 3 -> 12 channels.
        self.conv1 = nn.Conv2d(3, 12, 5)
        self.bn1 = nn.BatchNorm2d(12)
        self.pool1 = nn.MaxPool2d(2, 2)
        # Stage 2: 12 -> 24 channels.
        self.conv2 = nn.Conv2d(12, 24, 5)
        self.bn2 = nn.BatchNorm2d(24)
        self.pool2 = nn.MaxPool2d(2, 2)
        # Classifier head.
        self.fc1 = nn.Linear(24*53*53, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 2)

    def forward(self, x):
        """Return the two-class logits for a batch of images `x`."""
        stages = (
            (self.conv1, self.bn1, self.pool1),
            (self.conv2, self.bn2, self.pool2),
        )
        features = x
        for conv, norm, pool in stages:
            features = pool(F.relu(norm(conv(features))))

        # Flatten every sample to a 24*53*53 vector for the linear layers.
        hidden = features.view(-1, 24*53*53)
        for linear in (self.fc1, self.fc2):
            hidden = F.relu(linear(hidden))

        # No activation on the last layer: callers apply softmax themselves.
        return self.fc3(hidden)

In [None]:
# Rebuild the architecture and restore the first set of trained weights
# (trg_dataset1, per the checkpoint file name).
net1 = Net()
# map_location='cpu' keeps the load working on a runtime without a GPU even
# if the checkpoint tensors were saved from CUDA; load_state_dict then copies
# the values into net1's own parameters.
state_dict1 = torch.load('/content/drive/MyDrive/KASHIKO/MODELS/model_2021-05-29_12:16:11_ trg_dataset1 batch_size=100 learning_rate=0.001 scheduler_step_size=5 scheduler_gamma=1 weight_decay=0 epoch_number=11 accuracy=97.6.pth', map_location='cpu')
net1.load_state_dict(state_dict1)

In [None]:
# Rebuild the architecture and restore the second set of trained weights
# (trg_dataset2, per the checkpoint file name).
net2 = Net()
# map_location='cpu' keeps the load working on a runtime without a GPU even
# if the checkpoint tensors were saved from CUDA; load_state_dict then copies
# the values into net2's own parameters.
state_dict2 = torch.load('/content/drive/MyDrive/KASHIKO/MODELS/model_2021-05-29_18:41:53_ trg_dataset2 batch_size=100 learning_rate=0.001 scheduler_step_size=5 scheduler_gamma=1 weight_decay=0 epoch_number=16 accuracy=98.1.pth', map_location='cpu')
net2.load_state_dict(state_dict2)

In [None]:
# Rebuild the architecture and restore the third set of trained weights
# (trg_dataset3, per the checkpoint file name).
net3 = Net()
# map_location='cpu' keeps the load working on a runtime without a GPU even
# if the checkpoint tensors were saved from CUDA; load_state_dict then copies
# the values into net3's own parameters.
state_dict3 = torch.load('/content/drive/MyDrive/KASHIKO/MODELS/model_2021-05-30_08:24:13_ trg_dataset3 batch_size=100 learning_rate=0.001 scheduler_step_size=5 scheduler_gamma=1 weight_decay=0 epoch_number=19 accuracy=98.2.pth', map_location='cpu')
net3.load_state_dict(state_dict3)

### Compute Temperature Scaling Coefficients

In [None]:
# A freshly constructed nn.Module is in training mode, and Net contains
# BatchNorm layers. Switch to eval mode before calibrating so the logits are
# computed from the stored running statistics (as at inference time) and
# those statistics are not updated by the calibration pass.
net3.eval()
# Wrap the model whose temperature is to be fitted.
# NOTE(review): ModelWithTemperature comes from temperature_scaling.py, which
# is not visible here - confirm it does not switch the model back to train().
scaled_model = ModelWithTemperature(net3)

In [None]:
# Fit the temperature using the samples served by `loader`.
# NOTE(review): presumably the fitted value is reported by set_temperature and
# then copied by hand into the temp_factor_* constants - confirm against
# temperature_scaling.py.
scaled_model.set_temperature(loader)

### Save Temperature Factor for each model

In [None]:
# Temperature factor per model, entered by hand.
# NOTE(review): presumably transcribed from earlier set_temperature runs -
# re-check these values whenever the calibration subset changes.
temp_factor_net1, temp_factor_net2, temp_factor_net3 = 3.661, 4.289, 3.913

### Test performance improvement using temperature scaling

In [None]:
# Compare accuracy on all samples with accuracy restricted to "sure"
# predictions (top softmax probability above a threshold), once for the raw
# logits and once for the logits divided by the temperature factor.
#
# NOTE(review): `loader` is the same loader passed to
# scaled_model.set_temperature, so these figures are measured on the data the
# temperature was fitted on.

# Minimum top-class probability for a prediction to count as "sure".
_SURE_THRESHOLD = 0.97        # applied to the unscaled softmax
_SURE_THRESHOLD_TEMP = 0.73   # applied to the temperature-scaled softmax


def _percent(part, whole):
    """Return 100 * part / whole, or NaN when `whole` is zero."""
    # Avoids a ZeroDivisionError when no sample passes a threshold.
    return 100 * part / whole if whole else float("nan")


total_all = 0
correct_all = 0
total_sure = 0
correct_sure = 0
total_sure_temp = 0
correct_sure_temp = 0

m = nn.Softmax(dim=1)

# Set inference mode once instead of on every iteration.
net3.eval()
# Run on whatever device the model currently lives on, in case an earlier
# cell moved it off the CPU.
_device = next(net3.parameters()).device

with torch.no_grad():
    for images, labels in loader:
        images = images.to(_device)
        labels = labels.to(_device)

        out = net3(images)
        _, predicted = torch.max(out, 1)
        predicted_soft = m(out)
        predicted_soft_temp = m(out / temp_factor_net3)

        # Per-sample masks, so the counts stay correct for any batch size
        # (the original whole-batch maximum was only valid for batch_size=1).
        is_correct = predicted == labels
        is_sure = predicted_soft.max(dim=1)[0] > _SURE_THRESHOLD
        is_sure_temp = predicted_soft_temp.max(dim=1)[0] > _SURE_THRESHOLD_TEMP

        total_sure += is_sure.sum().item()
        correct_sure += (is_correct & is_sure).sum().item()
        total_sure_temp += is_sure_temp.sum().item()
        correct_sure_temp += (is_correct & is_sure_temp).sum().item()
        total_all += labels.size(0)
        correct_all += is_correct.sum().item()

# Accuracy (in percent) over all samples and over each "sure" subset.
test_accuracy_all = _percent(correct_all, total_all)
test_accuracy_sure = _percent(correct_sure, total_sure)
test_accuracy_sure_temp = _percent(correct_sure_temp, total_sure_temp)

print(test_accuracy_all)
print(test_accuracy_sure)
print(test_accuracy_sure_temp)

# Share (in percent) of samples that count as "sure" under each criterion.
print(_percent(total_sure, total_all))
print(_percent(total_sure_temp, total_all))