In [None]:
# Note - this cell should be executed only once per session

import sys, os

# in order to get the config, it is not part of the library

if os.path.basename(os.getcwd()) != "notebooks":
    raise Exception(f"Wrong directory. Did you execute this cell twice?")
os.chdir("..")
sys.path.append(os.path.abspath("."))

%load_ext autoreload
%autoreload 2

In [None]:
import os
import requests

from kyle.calibration.calibration_methods import TemperatureScaling, ClassWiseCalibration, \
    ConfidenceReducedCalibration, BetaCalibration
from kyle.evaluation import EvalStats
from kyle.models.resnet import load_weights, resnet20, resnet56
from kyle.datasets import resnet_denormalize_transform, get_cifar10_dataset

import torch
from scipy.special import softmax

from sklearn.model_selection import StratifiedShuffleSplit, cross_val_score

import numpy as np
import matplotlib.cm as cm
import matplotlib.pyplot as plt


## Loading Models and Data

In [None]:
# Adjust stuff in this cell
#   device          - torch device used for inference
#   selected_resnet - which pretrained CIFAR-10 ResNet to evaluate
device = "cpu"  # alternative: "cuda"
selected_resnet = "resnet20"  # alternative: "resnet56"

In [None]:
# File names of the pretrained weights, keyed by architecture name.
weights_file_names = dict(
    resnet20="resnet20-12fca82f.th",
    resnet56="resnet56-4bfd9763.th",
)

# One freshly constructed model per supported architecture; only the selected one is used.
models_dict = dict(
    resnet20=resnet20(),
    resnet56=resnet56(),
)

# Locations of the weights, the raw CIFAR-10 data and the cached logits, relative to the
# working directory set in the first cell.
resnet_path = os.path.join("data", "artifacts", weights_file_names[selected_resnet])
cifar_10_data_path = os.path.join("data", "raw", "cifar10")
logits_save_path = os.path.join(
    "data", "processed", "cifar10", "logits_{}.npy".format(selected_resnet)
)

In [None]:
# Loading model and data

if not os.path.isfile(resnet_path):
    print(f"Downloading weights for {selected_resnet}")
    # Create the directory that will contain the weights file. This must be the dirname:
    # the basename is the file name itself, so using it would create a stray directory
    # named like the weights file in the cwd and leave `data/artifacts` missing.
    os.makedirs(os.path.dirname(resnet_path), exist_ok=True)
    url = f"https://github.com/akamaster/pytorch_resnet_cifar10/raw/master/pretrained_models/{weights_file_names[selected_resnet]}"
    r = requests.get(url, timeout=60)
    # Fail loudly on HTTP errors instead of writing an error page to disk as the "weights".
    r.raise_for_status()
    with open(resnet_path, 'wb') as file:
        file.write(r.content)

resnet = models_dict[selected_resnet]
load_weights(resnet_path, resnet)
# Switch to inference mode and move the parameters to the configured device.
resnet.eval().to(device=device)

cifar_10_X, cifar_10_Y = get_cifar10_dataset(cifar_10_data_path)

# Sanity check: show one denormalized image. moveaxis moves axis 0 to the end, presumably
# (C, H, W) -> (H, W, C) as imshow expects - confirm against the dataset loader.
sample_image = torch.moveaxis(resnet_denormalize_transform(cifar_10_X[0]), 0, 2)
plt.figure(figsize=(1.5, 1.5))
plt.title("Sample Cifar Image")
plt.imshow(sample_image)
plt.show()

## Computing (or loading) Confidence Vectors

In [None]:
if os.path.isfile(logits_save_path):
    # Reuse the logits cached by a previous run.
    logits = np.load(logits_save_path)
else:
    # processing all at once may not fit into ram
    batch_size = 1000
    n_samples = len(cifar_10_X)
    # Ceiling division: the final, possibly partial, batch must not be dropped.
    n_batches = (n_samples + batch_size - 1) // batch_size

    logits = []
    # Inference only: no autograd graph is needed, which also keeps memory usage low.
    with torch.no_grad():
        for i, lower in enumerate(range(0, n_samples, batch_size)):
            print(f"Processing batch {i+1}/{n_batches}", end="\r")
            batch = cifar_10_X[lower:lower + batch_size].to(device)
            # Bring the result back to the CPU before converting, so device='cuda' also works.
            logits.append(resnet(batch).cpu().numpy())

    logits = np.vstack(logits)
    os.makedirs(os.path.dirname(logits_save_path), exist_ok=True)
    np.save(logits_save_path, logits, allow_pickle=False)


# Softmax over axis 1 turns each row of logits into a confidence vector.
confidences = softmax(logits, axis=1)
gt_labels = cifar_10_Y.numpy()
# Guards against a stale cache (e.g. produced by code that dropped a partial batch).
assert len(confidences) == len(gt_labels), (
    f"Cached logits do not match the dataset; delete {logits_save_path} and re-run."
)

# if you run out of ram: the raw images are not needed anymore.
# Note that re-running this cell afterwards requires re-running the loading cell first,
# because `cifar_10_Y` is read above.
try:
    del cifar_10_X
    del cifar_10_Y
except NameError:
    pass

## Visualizing Distribution of Confidences

In [None]:
# For every predicted class, plot the histogram of the top-class confidence over the samples
# predicted as that class. The left column holds the first half of the classes, the right
# column the second half.
# `matplotlib.cm.get_cmap` is deprecated (removed in Matplotlib 3.9); use the pyplot accessor.
cmap = plt.get_cmap("tab10")
bins = 50

n_classes = confidences.shape[1]
n_rows = (n_classes + 1) // 2
# Computed once instead of once per subplot.
predicted_classes = confidences.argmax(1)
max_confs = confidences.max(1)

fig, axes = plt.subplots(n_rows, 2, figsize=(12, 3.6 * n_rows), squeeze=False)
fig.suptitle("Distribution of confidences in predicted classes", fontsize=14)
for class_idx in range(n_classes):
    ax = axes[class_idx % n_rows, class_idx // n_rows]
    ax.set_title(f"Predicted Class {class_idx}")
    ax.hist(max_confs[predicted_classes == class_idx], density=True, color=cmap(class_idx), bins=bins)
    ax.set_xlabel("confidence")
    ax.set_ylabel("density")
# With an odd number of classes the last grid cell has nothing to show.
if n_classes % 2:
    axes[-1, -1].set_visible(False)

plt.show()


# Temperature Scaling on the Full and Reduced Problems

## Calibration Curves and ECE

Here are the reliability curve and the ECE of the uncalibrated ResNet.

In [None]:
# Number of bins used for the ECE of the uncalibrated model.
bins = 40
uncalibrated_eval_stats = EvalStats(
    gt_labels,
    confidences,
    bins=bins,
)

In [None]:
print("ECE full problem: {}".format(uncalibrated_eval_stats.expected_calibration_error()))

In [None]:
# Reliability curve of the top-class confidence before any recalibration.
uncalibrated_eval_stats.plot_reliability_curves([EvalStats.TOP_CLASS_LABEL], display_weights=True)
# The title marks this as the uncalibrated plot, distinguishing it from the recalibrated one.
# NOTE(review): plt.title targets the current axes; if display_weights adds extra axes,
# confirm the title lands on the intended one.
plt.title("Full problem (uncalibrated)")
plt.show()

## Simple Evaluation with Train/Validation Split

In [None]:
# Stratified split of the confidence vectors: only 20% of the data is used to fit the
# calibrators (test_size=0.8), the remaining 80% is used to evaluate them.
test_size = 0.8
bins = 30 # for ECE
split_seed = 42  # fixed so that the split, and every number computed from it, is reproducible

sss = StratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=split_seed)
train_index, test_index = next(sss.split(confidences, gt_labels))
confidences_train, gt_labels_train = confidences[train_index], gt_labels[train_index]
confidences_test, gt_labels_test = confidences[test_index], gt_labels[test_index]

## Reduced Temp Scaling

In [None]:
# The two calibrators being compared: temperature scaling on the full confidence vectors,
# and its "reduced" counterpart.
t_scaling_full, t_scaling_binary = TemperatureScaling(), ConfidenceReducedCalibration()

In [None]:
# Fit both calibrators on the training split only; the test split is kept for evaluation.
for _calibrator in (t_scaling_full, t_scaling_binary):
    _calibrator.fit(confidences_train, gt_labels_train)

In [None]:
# Apply both fitted calibrators to the held-out split.
recalibrated_full_confs, recalibrated_binary_confs = (
    t_scaling_full.get_calibrated_confidences(confidences_test),
    t_scaling_binary.get_calibrated_confidences(confidences_test),
)

In [None]:
# Evaluate both recalibrated confidence sets on the test split with the current `bins`.
recalibrated_full_eval_stats = EvalStats(gt_labels_test, recalibrated_full_confs, bins=bins)
recalibrated_binary_eval_stats = EvalStats(gt_labels_test, recalibrated_binary_confs, bins=bins)

for _label, _stats in (
    ("Reduced", recalibrated_binary_eval_stats),
    ("Full", recalibrated_full_eval_stats),
):
    print(f"{_label} ECE: {_stats.expected_calibration_error()}")

In [None]:
# Reliability curves of the top-class confidence after recalibration, one figure per method.
for _eval_stats, _title in (
    (recalibrated_full_eval_stats, "Full problem"),
    (recalibrated_binary_eval_stats, "Reduced problem"),
):
    _eval_stats.plot_reliability_curves([EvalStats.TOP_CLASS_LABEL], display_weights=True)
    plt.title(_title)
    plt.show()


## Class-wise Temp Scaling

In [None]:
# Class-wise temperature scaling, fitted on the same training split as above.
classwise_scaler = ClassWiseCalibration()
classwise_scaler.fit(
    confidences_train,
    gt_labels_train,
)

In [None]:
# Recalibrate the held-out split and collect its calibration statistics.
classwise_recalibrated_confs = classwise_scaler.get_calibrated_confidences(
    confidences_test
)
classwise_eval_stats = EvalStats(
    gt_labels_test, classwise_recalibrated_confs, bins=bins
)

In [None]:
# Reliability curve after class-wise recalibration; titled and shown like the other plots.
classwise_eval_stats.plot_reliability_curves([EvalStats.TOP_CLASS_LABEL], display_weights=True)
plt.title("Class-wise temperature scaling")
plt.show()

In [None]:
# Per-class (marginal) calibration error: class-wise vs. plain temperature scaling on the
# test split. The class count is read from the confidence vectors instead of hard-coding 10.
n_classes = confidences_test.shape[1]
for class_label in range(n_classes):
    print(f"Classwise marginal for class {class_label}: {classwise_eval_stats.expected_marginal_calibration_error(class_label)}")
    print(f"Full marginal for class {class_label}: {recalibrated_full_eval_stats.expected_marginal_calibration_error(class_label)}")

In [None]:
for _name, _stats in (
    ("Classwise", classwise_eval_stats),
    ("Full", recalibrated_full_eval_stats),
):
    print(f"{_name} ECE: {_stats.expected_calibration_error()}")

## Cross Validation

In [None]:
# Unfitted calibration methods, compared by cross-validation in the next cell.
# Baselines:
temp, beta = TemperatureScaling(), BetaCalibration()
# Class-wise and reduced variants. NOTE: `classwise_temt` is a typo of `classwise_temp`;
# the name is kept because the cross-validation cell refers to it.
classwise_temt, reduced_temp = ClassWiseCalibration(), ConfidenceReducedCalibration()
# ClassWiseCalibration receives a calibrator *class* here (presumably used as a factory -
# verify), whereas ConfidenceReducedCalibration receives a calibrator instance.
classwise_beta, classwise_reduced_temp = (
    ClassWiseCalibration(BetaCalibration),
    ConfidenceReducedCalibration(ClassWiseCalibration()),
)

In [None]:
cv = 3
bins = 100
class_for_marginal_error = 9

# NOTE: the folds are deterministic. With an integer `cv`, cross_val_score builds
# (Stratified)KFold splits without shuffling, so they are identical on every run.


def compute_score(scaler, confs: np.ndarray, labels: np.ndarray, metric="ECE", n_bins: int = bins):
    """Scorer for ``cross_val_score``: recalibrate ``confs`` with a fitted ``scaler`` and evaluate.

    :param scaler: fitted calibration method exposing ``get_calibrated_confidences``
    :param confs: confidence vectors of the held-out fold
    :param labels: ground-truth labels of the held-out fold
    :param metric: ``"ECE"`` for the expected calibration error, or an integer class label for
        the expected marginal calibration error of that class
    :param n_bins: number of bins passed to ``EvalStats``. It is bound when this cell runs, so
        the score does not change if the notebook-global ``bins`` is reassigned later.
    :return: the requested calibration error
    :raises ValueError: if ``metric`` is neither ``"ECE"`` nor an integer
    """
    calibrated_confs = scaler.get_calibrated_confidences(confs)
    eval_stats = EvalStats(labels, calibrated_confs, bins=n_bins)
    if metric == "ECE":
        return eval_stats.expected_calibration_error()
    elif isinstance(metric, (int, np.integer)):
        return eval_stats.expected_marginal_calibration_error(metric)
    else:
        raise ValueError(f"Unknown metric {metric}")


def compute_score_class(*args, **kwargs):
    """Scorer for the expected marginal calibration error of ``class_for_marginal_error``."""
    return compute_score(*args, metric=class_for_marginal_error, **kwargs)


def _cv_scores(calibrator, scoring):
    """``cv``-fold cross-validated scores of ``calibrator`` on all confidences."""
    return cross_val_score(calibrator, confidences, gt_labels, scoring=scoring, cv=cv)


temp_scores = _cv_scores(temp, compute_score)
beta_scores = _cv_scores(beta, compute_score)

cw_temp_scores = _cv_scores(classwise_temt, compute_score)
cw_beta_scores = _cv_scores(classwise_beta, compute_score)
reduced_temp_scores = _cv_scores(reduced_temp, compute_score)
cw_reduced_temp_scores = _cv_scores(classwise_reduced_temp, compute_score)

marginal_temp_scores = _cv_scores(temp, compute_score_class)
marginal_beta_scores = _cv_scores(beta, compute_score_class)
marginal_cw_temp_scores = _cv_scores(classwise_temt, compute_score_class)
marginal_cw_beta_scores = _cv_scores(classwise_beta, compute_score_class)

In [None]:
# `bins` here is whatever was assigned last; after running the cross-validation cell it is
# the value used for the scores.
print(f"CV with {cv} folds on {len(confidences)} data points. \n"
      f"Scores computed with {bins} bins", end="\n\n")


def print_results(scores, name):
    """Print the per-fold ``scores`` of method ``name``, followed by their mean and std."""
    print(f"scores for: {name}")
    print(scores)
    print("Mean and std")
    print(scores.mean(), scores.std(), end="\n\n\n")


print_results(temp_scores, "Baseline 1 - ECE, temperature")
print_results(beta_scores, "Baseline 2 - ECE, beta")
print_results(cw_temp_scores, "ECE, Class-wise temperature")
print_results(cw_beta_scores, "ECE, Class-wise beta")
print_results(reduced_temp_scores, "ECE, Reduced temperature")
print_results(cw_reduced_temp_scores, "ECE, Class-wise Reduced temperature")

print_results(marginal_temp_scores, "Baseline 3 - Marginal, temperature")
print_results(marginal_beta_scores, "Baseline 4 - Marginal, beta")
print_results(marginal_cw_temp_scores, "Marginal, Class-wise temperature")
print_results(marginal_cw_beta_scores, "Marginal, Class-wise beta")