In [None]:
# Note - this cell should be executed only once per session

import sys, os

# in order to get the config, it is not part of the library

if os.path.basename(os.getcwd()) != "notebooks":
    raise Exception(f"Wrong directory. Did you execute this cell twice?")
os.chdir("..")
sys.path.append(os.path.abspath("."))

%load_ext autoreload
%autoreload 2

In [None]:
from copy import copy
from kyle.calibration.calibration_methods import TemperatureScaling, ClassWiseCalibration, \
    ConfidenceReducedCalibration, BetaCalibration, BaseCalibrationMethod, IsotonicRegression, get_binary_classification_data
from kyle.evaluation import EvalStats

from sklearn.model_selection import StratifiedShuffleSplit, cross_val_score
from sklearn.datasets import load_iris, load_breast_cancer, make_classification
from sklearn.metrics import accuracy_score
from sklearn.ensemble import RandomForestClassifier

import numpy as np
import matplotlib.cm as cm
import matplotlib.pyplot as plt


## Generating Data and Training the Model

In [None]:
n_classes = 5

# random_state makes the synthetic dataset identical across kernel restarts.
dataset = make_classification(
    n_samples=60000, n_classes=n_classes, n_informative=15, random_state=42
)

X, y = dataset
# For a bundled sklearn dataset (e.g. load_iris()) unpack it instead with:
# X, y = dataset["data"], dataset["target"]

y.shape

In [None]:
test_size = 0.5
# Fixed random_state keeps the model train/test split reproducible.
sss = StratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=42)

# n_splits=1, so the generator yields exactly one (train, test) index pair.
train_index, test_index = next(sss.split(X, y))
X_train, y_train = X[train_index], y[train_index]
X_test, y_test = X[test_index], y[test_index]

In [None]:
model = RandomForestClassifier(random_state=42)  # seeded for reproducible confidences
model.fit(X_train, y_train)

confidences = model.predict_proba(X_test)
# predict_proba columns are ordered like model.classes_, so map argmax indices back to labels.
y_pred = model.classes_[confidences.argmax(1)]
accuracy = accuracy_score(y_test, y_pred)  # signature is accuracy_score(y_true, y_pred)
print(accuracy)

In [None]:
# Inputs for the calibration experiments below: `confidences` (the model's
# predict_proba output on X_test, computed above) and the matching ground truth.
# To study a different model, rebind both names here.
gt_labels = y_test


## Visualizing Distribution of Confidences

In [None]:
# pyplot.get_cmap works across matplotlib versions; matplotlib.cm.get_cmap was removed in 3.9.
cmap = plt.get_cmap("tab10")
# Own name, so re-running this cell does not overwrite the ECE `bins` defined further down.
hist_bins = 50

# Computed once instead of once per class inside the loop.
predicted_classes = confidences.argmax(1)
max_confidences = confidences.max(1)

# squeeze=False keeps `axes` 2D even for a single class; height scales with the number of rows.
fig, axes = plt.subplots(n_classes, figsize=(5, 2 * n_classes), squeeze=False, constrained_layout=True)
fig.suptitle("Distribution of confidences in predicted classes", fontsize=14)
for count, row in enumerate(axes[:, 0]):
    row.set_title(f"Predicted Class {count}")
    row.hist(max_confidences[predicted_classes == count], density=True, color=cmap(count), bins=hist_bins)

plt.show()


# Temperature Scaling: Standard, Reduced and Class-wise

## Simple Evaluation with Train/Validation Split

In [None]:
test_size = 0.5
bins = 20 # for ECE

# Split the held-out predictions once more: one half fits the calibrators, the other evaluates them.
# Fixed random_state keeps this split reproducible.
sss = StratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=42)
train_index, test_index = next(sss.split(confidences, gt_labels))
confidences_train, gt_labels_train = confidences[train_index], gt_labels[train_index]
confidences_test, gt_labels_test = confidences[test_index], gt_labels[test_index]

Here are the reliability curve and ECE of the uncalibrated random forest, evaluated on the calibration test split.

In [None]:
# Baseline statistics: raw model confidences on the evaluation split, before any recalibration.
uncalibrated_eval_stats = EvalStats(
    gt_labels_test,
    confidences_test,
    bins=bins,
)

In [None]:
uncalibrated_ece = uncalibrated_eval_stats.expected_calibration_error()
print(f"ECE uncalibrated: {uncalibrated_ece}")
# The argument 1 presumably selects the class label whose marginal error is reported.
uncalibrated_marginal = uncalibrated_eval_stats.expected_marginal_calibration_error(1)
print(f"Marginal uncalibrated: {uncalibrated_marginal}")


In [None]:
uncalibrated_eval_stats.plot_reliability_curves(
    [EvalStats.TOP_CLASS_LABEL],
    display_weights=True,
)
plt.gca().set_title("Uncalibrated reliabilities")
plt.show()

## Reduced Temp Scaling

In [None]:
# Standard temperature scaling next to its confidence-reduced variant.
t_scaling_full, t_scaling_binary = TemperatureScaling(), ConfidenceReducedCalibration()

In [None]:
# Both scalers are fitted on the same calibration-train split.
calibration_fit_data = (confidences_train, gt_labels_train)
t_scaling_full.fit(*calibration_fit_data)
t_scaling_binary.fit(*calibration_fit_data)

In [None]:
# Recalibrate the evaluation split with each fitted scaler (full first, then reduced).
recalibrated_full_confs, recalibrated_reduced_confs = (
    scaler.get_calibrated_confidences(confidences_test)
    for scaler in (t_scaling_full, t_scaling_binary)
)

In [None]:
recalibrated_full_eval_stats, recalibrated_binary_eval_stats = (
    EvalStats(gt_labels_test, calibrated, bins=bins)
    for calibrated in (recalibrated_full_confs, recalibrated_reduced_confs)
)

for caption, stats in (
    ("Temp Scaling ECE", recalibrated_full_eval_stats),
    ("Reduced Temp Scaling ECE", recalibrated_binary_eval_stats),
):
    print(f"{caption}: {stats.expected_calibration_error()}")

In [None]:
# Presumably reduces the recalibrated multi-class output to a binary problem; evaluate that.
binary_data = get_binary_classification_data(recalibrated_reduced_confs, gt_labels_test)
bin_confs, bin_gt = binary_data
bin_eval_stats = EvalStats(bin_gt, bin_confs, bins=bins)

In [None]:
# Reliability curve of the binary-reduced data; the figure previously had no title.
bin_eval_stats.plot_reliability_curves([0])
plt.title("Reduced temp scaling: binary-reduced reliabilities")
plt.show()

In [None]:
# One reliability figure per temperature-scaling variant.
for stats, plot_title in (
    (recalibrated_full_eval_stats, "Temp scaling"),
    (recalibrated_binary_eval_stats, "Reduced temp scaling"),
):
    stats.plot_reliability_curves([EvalStats.TOP_CLASS_LABEL], display_weights=True)
    plt.title(plot_title)
    plt.show()


## Class-wise Temp Scaling

In [None]:
%%capture

# %%capture suppresses all cell output, presumably to hide verbose logging
# while fitting the class-wise calibrator.
# NOTE(review): it also hides any warnings emitted during fitting; confirm this is intended.
classwise_scaler = ClassWiseCalibration()
classwise_scaler.fit(confidences_train, gt_labels_train)

In [None]:
# Apply the fitted class-wise calibrator to the evaluation split and collect its statistics.
classwise_recalibrated_confs = classwise_scaler.get_calibrated_confidences(
    confidences_test
)
classwise_eval_stats = EvalStats(
    gt_labels_test,
    classwise_recalibrated_confs,
    bins=bins,
)

In [None]:
classwise_eval_stats.plot_reliability_curves(
    [EvalStats.TOP_CLASS_LABEL],
    display_weights=True,
)
plt.gca().set_title("Class-wise Calibrated")
plt.show()

In [None]:
# Compare class-wise vs. plain temperature scaling on ECE and class-wise ECE, in this order.
for caption, metric_fn in (
    ("Class-wise Temp Scaling ECE", classwise_eval_stats.expected_calibration_error),
    ("Temp Scaling ECE", recalibrated_full_eval_stats.expected_calibration_error),
    ("Class-wise Temp Scaling cwECE", classwise_eval_stats.class_wise_expected_calibration_error),
    ("Temp Scaling cwECE", recalibrated_full_eval_stats.class_wise_expected_calibration_error),
):
    print(f"{caption}: {metric_fn()}")

## Cross Validation

In [None]:
def _baseline_wrapper(method_factory):
    """Build the calibration method directly from its factory."""
    return method_factory()


def _class_wise_wrapper(method_factory):
    """Wrap in ClassWiseCalibration, which receives the factory itself (not an instance)."""
    return ClassWiseCalibration(method_factory)


def _reduced_wrapper(method_factory):
    """Wrap a freshly built method instance in ConfidenceReducedCalibration."""
    return ConfidenceReducedCalibration(method_factory())


def _class_wise_reduced_wrapper(method_factory):
    """ClassWiseCalibration whose factory yields ConfidenceReducedCalibration-wrapped methods."""
    def reduced_factory():
        return ConfidenceReducedCalibration(method_factory())
    return ClassWiseCalibration(reduced_factory)


# Display label -> function turning a calibration-method factory into a ready-to-fit calibrator.
DEFAULT_WRAPPERS = {
    "Baseline": _baseline_wrapper,
    "Class-wise": _class_wise_wrapper,
    "Reduced": _reduced_wrapper,
    "Class-wise reduced": _class_wise_reduced_wrapper,
}

# Cross-validation folds and metric bins used when callers do not override them.
DEFAULT_CV: int = 5
DEFAULT_BINS: int = 20

# What perform_default_evaluation iterates over by default.
ALL_CALIBRATION_METHOD_FACTORIES = (
    TemperatureScaling,
    BetaCalibration,
    IsotonicRegression,
)
ALL_METRICS = ("ECE", "cwECE")

def compute_score(scaler, confs: np.ndarray, labels: np.ndarray, bins, metric="ECE"):
    """Recalibrate ``confs`` with a fitted ``scaler`` and return a calibration-error score.

    The leading ``(scaler, confs, labels)`` arguments follow sklearn's scorer convention
    ``scorer(estimator, X, y)``, which is how ``get_scores`` uses it.

    :param scaler: fitted calibration method exposing ``get_calibrated_confidences``.
    :param confs: confidences to recalibrate.
    :param labels: ground-truth labels matching ``confs``.
    :param bins: number of bins passed to ``EvalStats``.
    :param metric: ``"ECE"``, ``"cwECE"``, or an integer (Python or numpy) selecting the
        marginal calibration error passed to ``expected_marginal_calibration_error``.
    :return: the requested calibration error.
    :raises ValueError: for an unknown metric; checked before any calibration work is done.
    """
    # bool is a subclass of int but is never a meaningful class selector; numpy integers are.
    is_marginal = isinstance(metric, (int, np.integer)) and not isinstance(metric, bool)
    if not is_marginal and metric not in ("ECE", "cwECE"):
        raise ValueError(f"Unknown metric {metric}")

    calibrated_confs = scaler.get_calibrated_confidences(confs)
    eval_stats = EvalStats(labels, calibrated_confs, bins=bins)
    if is_marginal:
        return eval_stats.expected_marginal_calibration_error(int(metric))
    if metric == "ECE":
        return eval_stats.expected_calibration_error()
    return eval_stats.class_wise_expected_calibration_error()

def get_scores(scaler, metric, cv, bins, confs=None, labels=None):
    """Cross-validate ``scaler`` and return one calibration-error score per fold.

    ``cross_val_score`` fits a clone of ``scaler`` on each training fold and scores it on
    the held-out fold with ``compute_score``.

    :param scaler: unfitted calibration method (sklearn-compatible estimator).
    :param metric: metric forwarded to ``compute_score``.
    :param cv: number of folds (or any ``cv`` value accepted by ``cross_val_score``).
    :param bins: number of bins forwarded to ``compute_score``.
    :param confs: confidences to cross-validate on; defaults to the notebook-global ``confidences``.
    :param labels: ground-truth labels; defaults to the notebook-global ``gt_labels``.
    :return: array of per-fold scores as returned by ``cross_val_score``.
    """
    # Fall back to the notebook globals so existing calls keep working unchanged.
    if confs is None:
        confs = confidences
    if labels is None:
        labels = gt_labels

    def scoring(estimator, X, y):
        return compute_score(estimator, X, y, bins=bins, metric=metric)

    return cross_val_score(scaler, confs, labels, scoring=scoring, cv=cv)

def plot_scores(labels, scores_collection, title="", ax=None, y_lim=None, xtick_rotation=None):
    """Draw one box per entry of ``scores_collection``, labelled with ``labels``.

    :param labels: tick label for each box.
    :param scores_collection: sequence of score arrays, one per box.
    :param title: axes title.
    :param ax: axes to draw on; a new 14x7 figure with a single axes is created if None.
    :param y_lim: optional ``(bottom, top)`` y-axis limits.
    :param xtick_rotation: optional rotation in degrees for the x tick labels.
    :return: the axes that were drawn on.
    """
    if ax is None:
        # fig.axes of a fresh plt.figure() is an empty list, not an Axes, so create one explicitly.
        _, ax = plt.subplots(figsize=(14, 7))
    ax.set_title(title)
    # NOTE(review): matplotlib >= 3.9 renames boxplot's `labels` to `tick_labels`; kept for older versions.
    ax.boxplot(scores_collection, labels=labels)
    if y_lim is not None:
        ax.set_ylim(y_lim)
    if xtick_rotation is not None:
        # Rotates the existing tick labels; set_xticks has no standalone `rotation` option.
        ax.tick_params(axis="x", labelrotation=xtick_rotation)
    return ax


def evaluate_calibration_wrappers(method_factory, wrappers_dict=None, metric="ECE", cv=DEFAULT_CV, method_name=None,
        bins=DEFAULT_BINS, ax=None, short_title=False, y_lim=None):
    """Cross-validate ``method_factory`` under every wrapper and draw the scores as box plots.

    :param method_factory: zero-argument callable (e.g. a calibration class) creating a calibration method.
    :param wrappers_dict: mapping label -> wrapper turning ``method_factory`` into a calibrator;
        defaults to ``DEFAULT_WRAPPERS``.
    :param metric: metric forwarded to ``get_scores`` ("ECE", "cwECE" or an int for a marginal error).
    :param cv: number of cross-validation folds.
    :param method_name: name used in the plot title; defaults to the factory's ``__name__``.
    :param bins: number of bins for the calibration metric.
    :param ax: axes forwarded to ``plot_scores``.
    :param short_title: if True, the title is only the method name.
    :param y_lim: optional y-axis limits forwarded to ``plot_scores``.
    :return: ``(labels, scores_collection)``: the wrapper labels and one array of fold scores per wrapper.
    """
    if method_name is None:
        # Callables without __name__ (e.g. functools.partial) would otherwise raise AttributeError.
        method_name = getattr(method_factory, "__name__", repr(method_factory))
    if short_title:
        plot_title = f"{method_name}"
    else:
        # The data-point count is read from the notebook-global `confidences`.
        plot_title = f"Evaluating wrappers of {method_name} on metric {metric} with {bins} bins\n " \
                     f"CV with {cv} folds on {len(confidences)} data points."
    if wrappers_dict is None:
        wrappers_dict = DEFAULT_WRAPPERS

    labels = []
    scores_collection = []
    for label, wrapper in wrappers_dict.items():
        labels.append(label)
        method = wrapper(method_factory)
        scores_collection.append(get_scores(method, metric, cv=cv, bins=bins))
    plot_scores(labels, scores_collection, title=plot_title, ax=ax, y_lim=y_lim)
    return labels, scores_collection  # returned in case more than plotting is needed

Y_LIM = (0.001, 0.22) # taken such that minimum and maximum are visible in all plots

def perform_default_evaluation(method_factories=ALL_CALIBRATION_METHOD_FACTORIES, metrics=ALL_METRICS,
       figsize=(20, 7), y_lim=Y_LIM):
    """Plot one figure per metric, with one box-plot panel per calibration method.

    This may take a while with all methods and metrics.

    :param method_factories: calibration-method factories, one panel each (any iterable).
    :param metrics: metrics to evaluate, one figure each.
    :param figsize: size of each figure.
    :param y_lim: shared y-axis limits so panels are comparable.
    """
    # Materialise once: len() is needed and the factories are re-iterated for every metric,
    # which would silently yield nothing after the first metric if a generator were passed.
    method_factories = tuple(method_factories)
    for metric in metrics:
        # squeeze=False always returns a 2D array of axes, including the single-method case.
        fig, axes = plt.subplots(nrows=1, ncols=len(method_factories), figsize=figsize, squeeze=False)
        for col, method_factory in zip(axes[0], method_factories):
            evaluate_calibration_wrappers(method_factory, metric=metric, ax=col, short_title=True, y_lim=y_lim)
        fig.suptitle(f"Evaluation with {metric} ({DEFAULT_CV} folds; {DEFAULT_BINS} bins)")
        # fig.show() is meant for GUI backends and can warn in notebooks; plt.show() renders here.
        plt.show()

In [None]:
# Full cross-validated comparison of all methods, wrappers and metrics; this takes a while.
perform_default_evaluation(method_factories=ALL_CALIBRATION_METHOD_FACTORIES, metrics=ALL_METRICS)