In [None]:
# Note - this cell should be executed only once per session
import sys, os

# in order to get top level modules and to have paths relative to repo root

if os.path.basename(os.getcwd()) != "notebooks":
    raise Exception(f"Wrong directory. Did you execute this cell twice?")
os.chdir("..")
sys.path.append(os.path.abspath("."))

%load_ext autoreload
%autoreload 2

# Class-wise and Reduced Calibration Methods

In this notebook we demonstrate two new strategies for calibrating probabilistic classifiers. Both strategies
can be applied on top of any calibration algorithm and are therefore implemented as wrappers. We measure how
these wrappers change several calibration errors, using the non-wrapped calibration methods as baselines.

The tests are performed on random forests trained on two synthetic data sets (balanced and imbalanced) as well as
on resnet20 trained on the CIFAR10 data set.

In [None]:
from collections import defaultdict

from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
from sklearn.ensemble import RandomForestClassifier

import os
import requests
import logging

from kyle.calibration.calibration_methods import *
from kyle.evaluation import EvalStats

from scipy.special import softmax

from sklearn.model_selection import StratifiedShuffleSplit, cross_val_score

import numpy as np
import matplotlib.pyplot as plt

## Helper functions for evaluation

In [None]:
# Each wrapper turns a calibration-method factory into the calibrator object that
# evaluate_calibration_wrappers hands to get_scores for cross-validation.
DEFAULT_WRAPPERS = {
    # the plain calibration method, serving as baseline
    "Baseline": lambda method_factory: method_factory(),
    # ClassWiseCalibration receives the factory itself rather than an instance,
    # presumably so it can create one calibrator per class - TODO confirm in kyle
    "Class-wise": lambda method_factory: ClassWiseCalibration(method_factory),
    # ConfidenceReducedCalibration wraps a single instance of the method
    "Reduced": lambda method_factory: ConfidenceReducedCalibration(method_factory()),
    # both combined: the class-wise wrapper gets a factory producing reduced calibrators
    "Class-wise reduced": lambda method_factory: ClassWiseCalibration(
        lambda: ConfidenceReducedCalibration(method_factory())
    ),
}

# default number of cross-validation folds and of ECE bins (defaults of evaluate_calibration_wrappers)
DEFAULT_CV = 6
DEFAULT_BINS = 25

# calibration methods evaluated by perform_default_evaluation; the default wrappers call
# each entry without arguments to obtain a fresh calibrator
# NOTE(review): TemperatureScaling and LogisticCalibration are disabled without a stated reason
ALL_CALIBRATION_METHOD_FACTORIES = (
    # TemperatureScaling,
    BetaCalibration,
    # LogisticCalibration,
    IsotonicRegression,
    HistogramBinning,
)
# metric names understood by compute_score
ALL_METRICS = (
    "ECE",
    "cwECE",
)


def compute_score(scaler, confs: np.ndarray, labels: np.ndarray, bins, metric="ECE"):
    """Calibrate ``confs`` with an already fitted ``scaler`` and measure the calibration error.

    The signature matches sklearn's scorer protocol ``(estimator, X, y)``, with ``bins`` and
    ``metric`` bound by the caller (see get_scores).

    :param scaler: fitted calibrator exposing ``get_calibrated_confidences``
    :param confs: uncalibrated confidences to evaluate
    :param labels: ground-truth labels belonging to ``confs``
    :param bins: number of bins used for the ECE estimate
    :param metric: "ECE", "cwECE" (class-wise ECE) or an int selecting the ECE of one class label
    :return: the requested calibration error as computed by ``EvalStats``
    :raises ValueError: for any other value of ``metric``
    """
    stats = EvalStats(labels, scaler.get_calibrated_confidences(confs))
    if metric == "ECE":
        return stats.expected_calibration_error(n_bins=bins)
    if metric == "cwECE":
        return stats.class_wise_expected_calibration_error(n_bins=bins)
    if isinstance(metric, int):
        return stats.expected_calibration_error(class_label=metric, n_bins=bins)
    raise ValueError(f"Unknown metric {metric}")


def get_scores(scaler, metric, cv, bins, confs, labels):
    """Cross-validate ``scaler`` on ``(confs, labels)`` and return one score per fold.

    sklearn fits a clone of ``scaler`` on each training split and calls the scorer with the
    held-out split; the scorer forwards to compute_score with ``bins`` and ``metric`` fixed.

    :return: array of per-fold scores as returned by ``cross_val_score``
    """

    def _fold_scorer(*scorer_args):
        return compute_score(*scorer_args, bins=bins, metric=metric)

    return cross_val_score(scaler, confs, labels, scoring=_fold_scorer, cv=cv)


def plot_scores(wrapper_scores_dict: dict, title="", ax=None, y_lim=None):
    """Draw one boxplot per wrapper from a mapping ``wrapper name -> per-fold scores``.

    :param wrapper_scores_dict: mapping of wrapper name to a sequence of scores, e.g. the first
        element returned by evaluate_calibration_wrappers
    :param title: axes title
    :param ax: matplotlib axes to draw on; a new 14x7 figure is created when ``None``
    :param y_lim: optional ``(bottom, top)`` limits for the y-axis
    :return: the axes that were drawn on
    """
    import inspect

    # hand plain lists (not dict views) to matplotlib
    labels = list(wrapper_scores_dict.keys())
    scores_collection = list(wrapper_scores_dict.values())

    if ax is None:
        _, ax = plt.subplots(figsize=(14, 7))
    ax.set_title(title)

    # Matplotlib 3.9 renamed boxplot's `labels` to `tick_labels` (the old name is deprecated
    # and scheduled for removal); use whichever name the installed version accepts.
    try:
        boxplot_params = inspect.signature(ax.boxplot).parameters
    except (TypeError, ValueError):
        boxplot_params = {}
    labels_kwarg = "tick_labels" if "tick_labels" in boxplot_params else "labels"
    ax.boxplot(scores_collection, **{labels_kwarg: labels})

    if y_lim is not None:
        ax.set_ylim(y_lim)
    return ax


def evaluate_calibration_wrappers(
    method_factory,
    confidences,
    gt_labels,
    wrappers_dict=None,
    metric="ECE",
    cv=DEFAULT_CV,
    method_name=None,
    bins=DEFAULT_BINS,
    short_description=False,
):
    """Cross-validate every wrapper of one calibration method and collect the fold scores.

    :param method_factory: callable creating the calibration method (the default wrappers call
        it without arguments); its ``__name__`` is used when ``method_name`` is not given
    :param confidences: uncalibrated confidences used for cross-validation
    :param gt_labels: ground-truth labels belonging to ``confidences``
    :param wrappers_dict: mapping wrapper name -> function(method_factory) returning a
        calibrator; defaults to DEFAULT_WRAPPERS
    :param metric: metric name understood by compute_score
    :param cv: cross-validation specification forwarded to get_scores
    :param method_name: display name of the method
    :param bins: number of bins used for the ECE estimate
    :param short_description: only put the method name into the description
    :return: tuple ``(wrapper_scores_dict, description)`` mapping wrapper name to per-fold
        scores, plus a text suitable as plot title
    """
    method_name = method_factory.__name__ if method_name is None else method_name
    description = (
        f"{method_name}"
        if short_description
        else (
            f"Evaluating wrappers of {method_name} on metric {metric} with {bins} bins\n "
            f"CV with {cv} folds on {len(confidences)} data points."
        )
    )
    wrappers = DEFAULT_WRAPPERS if wrappers_dict is None else wrappers_dict

    scores_by_wrapper = {
        wrapper_name: get_scores(
            build_calibrator(method_factory),
            metric,
            cv=cv,
            bins=bins,
            confs=confidences,
            labels=gt_labels,
        )
        for wrapper_name, build_calibrator in wrappers.items()
    }
    return scores_by_wrapper, description


# Fixed y-axis limits per metric used by plot_default_evaluation_results, so that the
# boxplots of different calibration methods share one scale.
# Taken such that minimum and maximum are visible in all plots.
# NOTE(review): presumably chosen from the random-forest runs; plots of other models
# (e.g. the resnet section) may be clipped - confirm.
DEFAULT_Y_LIMS_DICT = {
    "ECE": (0.004, 0.032),
    "cwECE": (0.005, 0.018),
}


def perform_default_evaluation(
    confidences,
    gt_labels,
    method_factories=ALL_CALIBRATION_METHOD_FACTORIES,
    metrics=ALL_METRICS,
):
    """Run evaluate_calibration_wrappers for every (metric, calibration method) pair.

    Default wrappers, folds and bins are used, and short descriptions are requested so the
    results can be plotted side by side with plot_default_evaluation_results.

    :return: defaultdict mapping metric -> list of ``(wrapper_scores_dict, description)``,
        one entry per method factory in the given order
    """
    results_by_metric = defaultdict(list)
    for metric in metrics:
        print(f"Creating evaluation for {metric}")
        for factory in method_factories:
            # "\r" keeps the progress output on a single line
            print(f"Computing scores for {factory.__name__}", end="\r")
            results_by_metric[metric].append(
                evaluate_calibration_wrappers(
                    factory,
                    confidences=confidences,
                    gt_labels=gt_labels,
                    metric=metric,
                    short_description=True,
                )
            )
    return results_by_metric


def plot_default_evaluation_results(
    evaluation_results: dict, figsize=(25, 7), y_lims_dict=None, title_addon=None
):
    """Plot the output of perform_default_evaluation as one row of boxplots per metric.

    For every metric a figure with one subplot per evaluated calibration method is created;
    each subplot shows the per-fold scores of all wrappers of that method.

    :param evaluation_results: mapping metric -> list of ``(wrapper_scores_dict, description)``
        tuples, as returned by perform_default_evaluation
    :param figsize: size of each per-metric figure
    :param y_lims_dict: mapping metric -> y-axis limits; defaults to DEFAULT_Y_LIMS_DICT.
        Metrics without an entry are plotted with matplotlib's automatic limits.
    :param title_addon: optional extra line appended to each figure title
    """
    if y_lims_dict is None:
        y_lims_dict = DEFAULT_Y_LIMS_DICT
    for metric, results in evaluation_results.items():
        if not results:  # nothing to plot for this metric
            continue
        # squeeze=False always yields a 2D array of axes, also for a single subplot
        fig, axes = plt.subplots(
            nrows=1, ncols=len(results), figsize=figsize, squeeze=False
        )
        y_lim = y_lims_dict.get(metric)
        for ax, (wrapper_scores_dict, description) in zip(axes[0], results):
            plot_scores(wrapper_scores_dict, title=description, ax=ax, y_lim=y_lim)

        # perform_default_evaluation never overrides the default folds and bins
        title = f"Evaluation with {metric} ({DEFAULT_CV} folds; {DEFAULT_BINS} bins)"
        if title_addon is not None:
            title += f"\n{title_addon}"
        fig.suptitle(title)
        plt.show()

## Part 1: Random Forest


## Load Data

In [None]:
def get_calibration_dataset(
    n_classes=5,
    weights=None,
    n_samples=30000,
    n_informative=15,
    model=None,
    random_state=None,
):
    """Train a classifier on a synthetic dataset and return its held-out confidences.

    ``2 * n_samples`` points are generated with ``make_classification`` and split 50/50
    (stratified) into a train and a test half. ``model`` is fit on the train half; its
    ``predict_proba`` output on the test half is returned together with the test labels.

    :param n_classes: number of classes of the synthetic problem
    :param weights: class proportions forwarded to ``make_classification``
    :param n_samples: size of each half (train and returned test set)
    :param n_informative: number of informative features forwarded to ``make_classification``
    :param model: unfitted classifier with ``fit``/``predict_proba``; a fresh
        ``RandomForestClassifier`` is created per call when ``None`` (a default instance in
        the signature would be shared and refit across calls)
    :param random_state: seed forwarded to data generation, splitting and the default model;
        ``None`` keeps the non-deterministic behaviour
    :return: tuple ``(confidences, y_test)``; confidences has one row per test point
    """
    if model is None:
        model = RandomForestClassifier(random_state=random_state)

    X, y = make_classification(
        n_samples=2 * n_samples,
        n_classes=n_classes,
        n_informative=n_informative,
        weights=weights,
        random_state=random_state,
    )
    splitter = StratifiedShuffleSplit(
        n_splits=1, test_size=0.5, random_state=random_state
    )
    train_index, test_index = next(splitter.split(X, y))
    X_train, y_train = X[train_index], y[train_index]
    X_test, y_test = X[test_index], y[test_index]

    model.fit(X_train, y_train)
    confidences = model.predict_proba(X_test)
    # predict_proba columns follow model.classes_; map argmax indices back to labels
    pred_indices = confidences.argmax(axis=1)
    classes = getattr(model, "classes_", None)
    y_pred = pred_indices if classes is None else np.asarray(classes)[pred_indices]
    accuracy = accuracy_score(y_test, y_pred)
    print(f"Model accuracy: {accuracy}")
    return confidences, y_test

In [None]:
# this takes a while - a random forest is trained for each of the two datasets
print("Creating balanced dataset")
balanced_confs, balanced_gt = get_calibration_dataset()

print("Creating unbalanced dataset")
# four weights for the default five classes; make_classification infers the last class weight
unbalanced_confs, unbalanced_gt = get_calibration_dataset(weights=(0.3, 0.1, 0.25, 0.15))

## Evaluating wrappers on a single calibration method

In [None]:
# all default wrappers of histogram binning on the balanced dataset, 4-fold CV
balanced_scores_ECE, description = evaluate_calibration_wrappers(
    HistogramBinning, confidences=balanced_confs, gt_labels=balanced_gt, metric="ECE", cv=4
)
plot_scores(balanced_scores_ECE, title=description)
plt.show()

In [None]:
# all default wrappers of temperature scaling on the unbalanced dataset (default number of folds)
unbalanced_scores_ECE, description = evaluate_calibration_wrappers(
    TemperatureScaling, confidences=unbalanced_confs, gt_labels=unbalanced_gt, metric="ECE"
)
plot_scores(unbalanced_scores_ECE, title=description)
plt.show()

## Evaluating wrappers on multiple metrics and plotting next to each other

In [None]:
# every default method x wrapper x metric on the balanced random-forest confidences
eval_results = perform_default_evaluation(balanced_confs, balanced_gt)

In [None]:
plot_default_evaluation_results(evaluation_results=eval_results, title_addon="Balanced")

In [None]:
# every default method x wrapper x metric on the unbalanced random-forest confidences
unbalanced_eval_results = perform_default_evaluation(unbalanced_confs, unbalanced_gt)

In [None]:
plot_default_evaluation_results(
    evaluation_results=unbalanced_eval_results, title_addon="Unbalanced"
)

## Part 2: Resnet

Here we will repeat the evaluation of calibration methods on a neural network, specifically
on resnet20 trained on the CIFAR10 data set.

**Important:** to run the resnet part you need the packages from `requirements-torch.txt`.

In [None]:
from kyle.models.resnet import load_weights, resnet20, resnet56
from kyle.datasets import get_cifar10_dataset

In [None]:
selected_resnet = "resnet20"

weights_file_names = {
    "resnet20": "resnet20-12fca82f.th",
    "resnet56": "resnet56-4bfd9763.th",
}

# constructors rather than instances, so that only the selected architecture is built
models_dict = {
    "resnet20": resnet20,
    "resnet56": resnet56,
}


resnet_path = os.path.join("data", "artifacts", weights_file_names[selected_resnet])
cifar_10_data_path = os.path.join("data", "raw", "cifar10")
logits_save_path = os.path.join(
    "data", "processed", "cifar10", f"logits_{selected_resnet}.npy"
)

if not os.path.isfile(resnet_path):
    print(
        f"Downloading weights for {selected_resnet} to {os.path.abspath(resnet_path)}"
    )
    os.makedirs(os.path.dirname(resnet_path), exist_ok=True)
    url = f"https://github.com/akamaster/pytorch_resnet_cifar10/raw/master/pretrained_models/{weights_file_names[selected_resnet]}"
    r = requests.get(url, timeout=120)
    # fail loudly instead of caching an HTTP error page as the weights file: since the
    # download is skipped once the file exists, a bad file would otherwise persist
    r.raise_for_status()
    with open(resnet_path, "wb") as file:
        file.write(r.content)

resnet = models_dict[selected_resnet]()
load_weights(resnet_path, resnet)
resnet.eval()


def get_cifar10_confidences(batch_size=1000):
    """Return softmax confidences of ``resnet`` on the CIFAR10 data and the matching labels.

    Logits are cached at ``logits_save_path``; if that file exists it is loaded instead of
    running the network again.

    :param batch_size: number of images passed through the network at once (processing all
        images at once may not fit into RAM)
    :return: tuple ``(confidences, gt_labels)`` as numpy arrays
    :raises ValueError: if the (possibly cached) logits do not match the number of labels
    """
    import torch  # required for the resnet part anyway (requirements-torch.txt)

    cifar_10_X, cifar_10_Y = get_cifar10_dataset(cifar_10_data_path)
    gt_labels = cifar_10_Y.numpy()

    if os.path.isfile(logits_save_path):
        logits = np.load(logits_save_path)
    else:
        n_samples = len(cifar_10_X)
        n_batches = -(-n_samples // batch_size)  # ceiling division
        batch_logits = []
        # inference only: don't build an autograd graph for every batch
        with torch.no_grad():
            for batch_idx, lower in enumerate(range(0, n_samples, batch_size)):
                print(f"Processing batch {batch_idx + 1}/{n_batches}", end="\r")
                # slicing past the end is safe, so a final partial batch is included
                batch = cifar_10_X[lower : lower + batch_size]
                batch_logits.append(resnet(batch).detach().numpy())

        logits = np.vstack(batch_logits)
        os.makedirs(os.path.dirname(logits_save_path), exist_ok=True)
        np.save(logits_save_path, logits, allow_pickle=False)

    if len(logits) != len(gt_labels):
        raise ValueError(
            f"Got {len(logits)} logits for {len(gt_labels)} labels; "
            f"delete {logits_save_path} to recompute the logits."
        )

    confidences = softmax(logits, axis=1)
    return confidences, gt_labels

In [None]:
# runs the selected resnet on CIFAR10 (or loads cached logits) and applies softmax
(cifar_confs, cifar_gt) = get_cifar10_confidences()

## Evaluating wrappers on a single calibration method

In [None]:
# all default wrappers of histogram binning on the resnet confidences, 4-fold CV
resnet_scores_ECE, description = evaluate_calibration_wrappers(
    HistogramBinning,
    confidences=cifar_confs,
    gt_labels=cifar_gt,
    metric="ECE",
    cv=4,
)
plot_scores(resnet_scores_ECE, title=description)
plt.show()

In [None]:
# all default wrappers of temperature scaling on the resnet confidences, 4-fold CV
resnet_scores_ECE, description = evaluate_calibration_wrappers(
    TemperatureScaling,
    confidences=cifar_confs,
    gt_labels=cifar_gt,
    metric="ECE",
    cv=4,
)
plot_scores(resnet_scores_ECE, title=description)
plt.show()

## Evaluating wrappers on multiple metrics and plotting next to each other

In [None]:
# every default method x wrapper x metric on the resnet CIFAR10 confidences
# (previously evaluated the balanced random-forest data by mistake)
eval_results = perform_default_evaluation(
    confidences=cifar_confs, gt_labels=cifar_gt
)

In [None]:
plot_default_evaluation_results(
    evaluation_results=eval_results,
    title_addon=f"{selected_resnet} on CIFAR10",
)