# ECE plot with error bars given by the confidence distribution in the bins

In [None]:
import os
os.getcwd()

In [None]:
import argparse
import yaml

import torch
import pycalib
from laplace import Laplace

import utils.data_utils as du
import utils.wilds_utils as wu
import utils.utils as util
from utils.test import test
from marglik_training.train_marglik import get_backend

# import warnings
# warnings.filterwarnings('ignore')

from argparse import Namespace

from tqdm import tqdm

import matplotlib.pyplot as plt

from copy import deepcopy

from random import randint

import numpy as np

In [None]:
from tueplots import bundles



# Matplotlib defaults inspired by bundles.neurips2023(), with the font sizes
# adapted for the pt12 standard.
settings_dict = {
    # Text rendering
    'text.usetex': True,
    'font.family': 'serif',
    'text.latex.preamble': '\\renewcommand{\\rmdefault}{ptm}\\renewcommand{\\sfdefault}{phv}',
    # Figure geometry and layout
    'figure.figsize': (5.5, 3.399186938124422),
    'figure.constrained_layout.use': True,
    'figure.autolayout': False,
    # Saving
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.015,
    # Font sizes
    'font.size': 10,
    'axes.labelsize': 10,
    'legend.fontsize': 8,
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
    'axes.titlesize': 10,
    # Resolution
    'figure.dpi': 300,
}

# Apply globally so every figure created afterwards picks these settings up
plt.rcParams.update(settings_dict)


# Can use colors from bundles.rgb.
#     tue_blue
#     tue_brown
#     tue_dark
#     tue_darkblue
#     tue_darkgreen
#     tue_gold
#     tue_gray
#     tue_green
#     tue_lightblue
#     tue_lightgold
#     tue_lightgreen
#     tue_lightorange
#     tue_mauve
#     tue_ocre
#     tue_orange
#     tue_red
#     tue_violet

In [None]:
def batch_cov(points):
    """Compute the unbiased sample covariance matrix of every batch element.

    Parameters
    ----------
    points : torch.Tensor
        `(B, N, D)`: for each of the `B` batch elements, `N` observations of
        a `D`-dimensional variable.

    Returns
    -------
    torch.Tensor
        `(B, D, D)` covariance matrices, normalised by `N - 1`.
    """
    B, N, D = points.size()
    centered = points - points.mean(dim=1, keepdim=True)
    # A single batched matmul contracts over the N observations directly, so
    # the (B, N, D, D) tensor of per-observation outer products is never built.
    bcov = torch.bmm(centered.transpose(1, 2), centered) / (N - 1)  # Unbiased estimate
    return bcov  # (B, D, D)


In [None]:
def normal_samples(mean, var, n_samples, generator=None):
    """Draw samples from a batch of Normal distributions whose covariance,
    given by `var`, is either diagonal or full.

    A single standard-normal noise matrix of shape `(output_dim, n_samples)`
    is drawn and reused (via broadcasting) for every element of the batch.

    Parameters
    ----------
    mean : torch.Tensor
        `(batch_size, output_dim)`
    var : torch.Tensor
        (co)variance of the Normal distribution
        `(batch_size, output_dim, output_dim)` or `(batch_size, output_dim)`
    n_samples : int
        number of samples drawn per batch element
    generator : torch.Generator
        random number generator

    Returns
    -------
    torch.Tensor
        `(n_samples, batch_size, output_dim)`

    Raises
    ------
    ValueError
        if the shape of `var` matches neither the diagonal nor the full layout
    """
    assert mean.ndim == 2, 'Invalid input shape of mean, should be 2-dimensional.'
    output_dim = mean.shape[1]
    noise = torch.randn((output_dim, n_samples), device=mean.device,
                        dtype=mean.dtype, generator=generator)
    noise = noise.unsqueeze(0)  # expand batch dim

    if mean.shape == var.shape:
        # diagonal covariance: scale every dimension by its standard deviation
        perturbation = var.sqrt().unsqueeze(-1) * noise
    elif mean.shape == var.shape[:2] and var.shape[-1] == mean.shape[1]:
        # full covariance: transform the noise with the Cholesky factor
        perturbation = torch.matmul(torch.linalg.cholesky(var), noise)
    else:
        raise ValueError('Invalid input shapes.')

    return (mean.unsqueeze(-1) + perturbation).permute((2, 0, 1))



In [None]:
def calculate_confs_preds_variances_ypreds(f_mu, f_var, y_true, n_samples=10000, generator=None, batchsize=128):
    """Monte-Carlo summary of the predictive distribution for every input.

    For each input, `n_samples` logit vectors are drawn with `normal_samples`
    from the Normal distribution given by `f_mu` / `f_var` and passed through
    a softmax. The inputs are processed in chunks of `batchsize`.

    Parameters
    ----------
    f_mu : torch.Tensor
        `(n, num_classes)` mean of the logits
    f_var : torch.Tensor
        (co)variance of the logits, in one of the layouts accepted by
        `normal_samples`
    y_true : torch.Tensor
        labels; only `y_true.shape[0]` is used, to determine how many inputs
        are processed
    n_samples : int
        number of Monte-Carlo samples per input
    generator : torch.Generator
        random number generator forwarded to `normal_samples`
    batchsize : int
        number of inputs sampled at once

    Returns
    -------
    confs : torch.Tensor
        `(n,)` largest mean class probability per input
    preds : torch.Tensor
        `(n,)` index of that class
    variances : torch.Tensor
        `(n,)` unbiased sample variance of the predicted class' probability
    y_preds : torch.Tensor
        `(n, num_classes)` mean class probabilities
    """
    confs_list = []
    preds_list = []
    variances_list = []
    y_preds_list = []

    for start in tqdm(range(0, y_true.shape[0], batchsize)):
        stop = start + batchsize
        f_samples = normal_samples(f_mu[start:stop], f_var[start:stop], n_samples, generator)
        y_prob = torch.softmax(f_samples, dim=-1)  # (n_samples, batch, num_classes)

        y_pred = y_prob.mean(dim=0)
        confs, preds = torch.max(y_pred, 1)

        # Only the variance of the predicted class is needed. It equals the
        # matching diagonal entry of the sample covariance, so the full
        # (num_classes x num_classes) covariance per input is never formed.
        class_variances = y_prob.var(dim=0, unbiased=True)  # (batch, num_classes)
        variances = class_variances.gather(1, preds.unsqueeze(1)).squeeze(1)

        confs_list.append(confs)
        preds_list.append(preds)
        variances_list.append(variances)
        y_preds_list.append(y_pred)

    confs_list = torch.cat(confs_list)
    preds_list = torch.cat(preds_list)
    variances_list = torch.cat(variances_list)
    y_preds_list = torch.cat(y_preds_list)

    return confs_list, preds_list, variances_list, y_preds_list

In [None]:
def calculate_ece(outputs, labels, n_bins=10):
    """
    Expected Calibration Error (ECE) of a set of predictions.

    The confidence (largest softmax score) of every sample is sorted into one
    of `n_bins` equal-width bins `(lower, upper]` on [0, 1]. For each
    non-empty bin the gap
        | avg_confidence_in_bin - accuracy_in_bin |
    is computed, and the gaps are averaged with weights equal to the fraction
    of samples falling into the bin.

    outputs - a torch tensor (size n x num_classes) with softmax scores
    labels - a torch tensor (size n) with the labels
    Returns the ECE as a Python float.

    See: Naeini, Mahdi Pakdaman, Gregory F. Cooper, and Milos Hauskrecht.
    "Obtaining Well Calibrated Probabilities Using Bayesian Binning." AAAI.
    2015.
    """
    edges = torch.linspace(0, 1, n_bins + 1)
    top_scores, top_classes = torch.max(outputs, 1)
    hits = top_classes.eq(labels)

    total = torch.zeros(1, device=outputs.device)
    for k in range(n_bins):
        lower = edges[k].item()
        upper = edges[k + 1].item()
        mask = top_scores.gt(lower) * top_scores.le(upper)
        weight = mask.float().mean()
        if not weight.item() > 0:
            # empty bin (or no samples at all): contributes nothing
            continue
        bin_accuracy = hits[mask].float().mean()
        bin_confidence = top_scores[mask].mean()
        total += torch.abs(bin_confidence - bin_accuracy) * weight
    return total.item()

### Plots for PR

In [None]:
import os
print(os.getcwd())
import argparse
import yaml

import torch
import pycalib
from laplace import Laplace

import utils.data_utils as du
import utils.wilds_utils as wu
import utils.utils as util
from utils.test import test
from marglik_training.train_marglik import get_backend

# import warnings
# warnings.filterwarnings('ignore')

from argparse import Namespace

from tqdm import tqdm

import matplotlib.pyplot as plt

from copy import deepcopy

from random import randint

import numpy as np

In [None]:
def batch_cov(points):
    """Compute the unbiased sample covariance matrix of every batch element.

    Parameters
    ----------
    points : torch.Tensor
        `(B, N, D)`: for each of the `B` batch elements, `N` observations of
        a `D`-dimensional variable.

    Returns
    -------
    torch.Tensor
        `(B, D, D)` covariance matrices, normalised by `N - 1`.
    """
    B, N, D = points.size()
    centered = points - points.mean(dim=1, keepdim=True)
    # A single batched matmul contracts over the N observations directly, so
    # the (B, N, D, D) tensor of per-observation outer products is never built.
    bcov = torch.bmm(centered.transpose(1, 2), centered) / (N - 1)  # Unbiased estimate
    return bcov  # (B, D, D)



In [None]:
def make_reliability_diagram_with_conf_uncertainty(y_pred, variances, labels, n_bins=40, title_name=""):
    """Draw a reliability diagram with error bars for the per-bin mean confidence.

    Samples are grouped into `n_bins` equal-width bins by their confidence
    (largest predicted probability). For every bin the accuracy is drawn as a
    bar, the gap to the mean confidence as a hatched bar on top of it, and a
    horizontal error bar of two standard deviations of the bin's mean
    confidence (sqrt(mean variance in bin / samples in bin)) is placed on the
    diagonal at the bin centre.

    The plot is drawn on pyplot figure number 0; it is neither shown nor
    saved here and nothing is returned.

    Parameters
    ----------
    y_pred : torch.Tensor
        `(n, num_classes)` predicted class probabilities (after softmax)
    variances : torch.Tensor
        `(n,)` variance of the predicted class' probability for every sample
    labels : torch.Tensor
        `(n,)` true labels
    n_bins : int
        number of equal-width confidence bins on [0, 1]
    title_name : str
        title of the figure
    """
    confidences, predictions = y_pred.max(1)
    accuracies = torch.eq(predictions, labels)

    # Equal-width bins [lower, upper) over the confidence axis
    bins = torch.linspace(0, 1, n_bins + 1)
    width = 1.0 / n_bins
    bin_centers = np.linspace(0, 1.0 - width, n_bins) + width / 2
    bin_indices = [confidences.ge(bin_lower) * confidences.lt(bin_upper) for bin_lower, bin_upper in zip(bins[:-1], bins[1:])]
    # The last bin gets no upper bound: with a half-open last bin, samples
    # whose confidence is exactly 1.0 (saturated softmax) would belong to no
    # bin and silently disappear from the diagram.
    bin_indices[-1] = confidences.ge(bins[-2])

    # Variance of a bin's mean confidence: mean per-sample variance / count.
    # Empty bins have a NaN mean; their count is clamped to 1 to avoid a
    # division by zero.
    bin_mean_variances = np.array([float(variances[bin_index].mean()) for bin_index in bin_indices])
    bin_counts = np.array([int(bin_index.sum()) for bin_index in bin_indices])
    bin_counts = np.maximum(bin_counts, 1)
    bin_mean_variances = bin_mean_variances * (1 / bin_counts)
    bin_mean_2stds = np.sqrt(bin_mean_variances) * 2

    # Per-bin accuracy and mean confidence; empty bins (NaN) are drawn as 0
    bin_corrects = np.nan_to_num(np.array([float(torch.mean(accuracies[bin_index].float())) for bin_index in bin_indices]))
    bin_scores = np.nan_to_num(np.array([float(torch.mean(confidences[bin_index].float())) for bin_index in bin_indices]))
    gap = bin_scores - bin_corrects

    figsize = (2.75, 2.75)
    plt.figure(0, figsize=figsize)

    confs = plt.bar(bin_centers, bin_corrects, color=[0, 0, 1], width=width, ec='black')
    gaps = plt.bar(bin_centers, gap, bottom=bin_corrects, color=[1, 0.7, 0.7], alpha=0.5, width=width, hatch='//', edgecolor='r')
    # The error bars sit on the diagonal and extend along the confidence axis
    errorbars = plt.errorbar(bin_centers, bin_centers, fmt=".", xerr=bin_mean_2stds)

    plt.plot([0, 1], [0, 1], '--', color='gray')
    plt.legend([confs, gaps, errorbars], ['Accuracy', 'Gap', r'$2\sigma$ of the Confidence'], loc='upper left')

    plt.title(f"{title_name}")
    plt.ylabel("Accuracy")
    plt.xlabel("Confidence")
    plt.xlim(0,1)
    plt.ylim(0,1)

# Camelyon17

In [None]:
# Output directory for the reliability-diagram PDFs. exist_ok=True makes
# re-running the cell safe without a separate (race-prone) existence check.
reliabilityDiagramsSavedir = "./results/img/Results/ReliabilityPlotsWithUncertainty"
os.makedirs(reliabilityDiagramsSavedir, exist_ok=True)

In [None]:
# Directories with the stored predictive distributions; title_names holds the
# matching method label, in the same order.
distribution_directories = [
    './results/predictive_distributions/camelyon17/',
    './results/predictive_distributions/camelyon17_ts/',
    './results/predictive_distributions/camelyon17_scaling/',
    # './results/predictive_distributions/camelyon17_diagadd_fitted/',
    # './results/predictive_distributions/camelyon17_diagscaling_fitted/',
    # './results/predictive_distributions/camelyon17_scaling_only_simple_scaling_fitted/',
    './results/predictive_distributions/camelyon17_ts_and_scaling_fitted/',
    # './results/predictive_distributions/camelyon17_ts_and_scaling_fitted_on_ood_val/',
    # './results/predictive_distributions/camelyon17_ts_and_scaling_fitted_model6/',
    # './results/predictive_distributions/camelyon17_ts_and_scaling_fitted_on_ood_val_model6/'
]
title_names = ["LLLA", "LLLA+WITS", "LLLA+CVS", "LLLA+WITS+CVS"]

# Human-readable dataset names used in the figure titles
DatasetTranslationDict = {
    'camelyon17-id': 'WILDS-Camelyon17 (ID)',
    'camelyon17-ood': 'WILDS-Camelyon17 (OOD)',
    'amazon-id': 'WILDS-Amazon (ID)',
    'amazon-ood': 'WILDS-Amazon (OOD)',
    'SkinLesions-id': 'SkinLesions (ID)',
    'SkinLesions-ood': 'SkinLesions (OOD)',
}

# One reliability diagram per (method, dataset split), saved as PDF and shown
for DISTRIBUTIONS_DIRECTORY, title_name in zip(distribution_directories, title_names):
    for DATASET in ['camelyon17-id', 'camelyon17-ood']:
        print("### ", DISTRIBUTIONS_DIRECTORY, " - ", DATASET)

        # Labels plus mean / (co)variance of the logits for this split
        y_true = torch.load(os.path.join(DISTRIBUTIONS_DIRECTORY, f"y_true_{DATASET}.pt"))
        f_mu = torch.load(os.path.join(DISTRIBUTIONS_DIRECTORY, f"f_mu_{DATASET}.pt"))
        f_var = torch.load(os.path.join(DISTRIBUTIONS_DIRECTORY, f"f_var_{DATASET}.pt"))

        confs, preds, variances, y_pred = calculate_confs_preds_variances_ypreds(f_mu, f_var, y_true)

        title = f'{title_name} on {DatasetTranslationDict[DATASET]}'
        make_reliability_diagram_with_conf_uncertainty(y_pred, variances, y_true, title_name=title)
        plt.savefig(os.path.join(reliabilityDiagramsSavedir, f'{title_name}_{DATASET}.pdf'))
        plt.show()




# Camelyon17 - ResNet50

In [None]:
# Directories with the stored predictive distributions of the ResNet50 runs;
# title_names holds the matching method label, in the same order.
distribution_directories = [
    './results/predictive_distributions/camelyon17_resnet50/',  # TODO or use:
    # './results/predictive_distributions/camelyon17_ts/', #TODO or use this one?
    './results/predictive_distributions/camelyon17_resnet50_ts_and_scaling_fitted',
]
title_names = ["LLLA", "LLLA+WITS+CVS"]

# Human-readable dataset names used in the figure titles (the Camelyon17
# entries carry no "WILDS-" prefix in this cell)
DatasetTranslationDict = {
    'camelyon17-id': 'Camelyon17 (ID)',
    'camelyon17-ood': 'Camelyon17 (OOD)',
    'amazon-id': 'WILDS-Amazon (ID)',
    'amazon-ood': 'WILDS-Amazon (OOD)',
    'SkinLesions-id': 'SkinLesions (ID)',
    'SkinLesions-ood': 'SkinLesions (OOD)',
}

# One reliability diagram per (method, dataset split), saved as PDF and shown
for DISTRIBUTIONS_DIRECTORY, title_name in zip(distribution_directories, title_names):
    for DATASET in ['camelyon17-id', 'camelyon17-ood']:
        print("### ", DISTRIBUTIONS_DIRECTORY, " - ", DATASET)

        # Labels plus mean / (co)variance of the logits for this split
        y_true = torch.load(os.path.join(DISTRIBUTIONS_DIRECTORY, f"y_true_{DATASET}.pt"))
        f_mu = torch.load(os.path.join(DISTRIBUTIONS_DIRECTORY, f"f_mu_{DATASET}.pt"))
        f_var = torch.load(os.path.join(DISTRIBUTIONS_DIRECTORY, f"f_var_{DATASET}.pt"))

        confs, preds, variances, y_pred = calculate_confs_preds_variances_ypreds(f_mu, f_var, y_true)

        title = f'{title_name} with ResNet50 on {DatasetTranslationDict[DATASET]}'
        make_reliability_diagram_with_conf_uncertainty(y_pred, variances, y_true, title_name=title)
        plt.savefig(os.path.join(reliabilityDiagramsSavedir, f'{title_name}_{DATASET}_resnet50.pdf'))
        plt.show()




# Amazon

In [None]:
# Directories with the stored predictive distributions; title_names holds the
# matching method label, in the same order (note: scaling before ts here).
distribution_directories = [
    './results/predictive_distributions/amazon/',
    './results/predictive_distributions/amazon_scaling/',
    './results/predictive_distributions/amazon_ts/',
    './results/predictive_distributions/amazon_ts_and_scaling_fitted/',
]
title_names = ["LLLA", "LLLA+CVS", "LLLA+WITS", "LLLA+WITS+CVS"]

# Human-readable dataset names used in the figure titles
DatasetTranslationDict = {
    'camelyon17-id': 'WILDS-Camelyon17 (ID)',
    'camelyon17-ood': 'WILDS-Camelyon17 (OOD)',
    'amazon-id': 'WILDS-Amazon (ID)',
    'amazon-ood': 'WILDS-Amazon (OOD)',
    'SkinLesions-id': 'SkinLesions (ID)',
    'SkinLesions-ood': 'SkinLesions (OOD)',
}

# One reliability diagram per (method, dataset split), saved as PDF and shown
for DISTRIBUTIONS_DIRECTORY, title_name in zip(distribution_directories, title_names):
    for DATASET in ['amazon-id', 'amazon-ood']:
        print("### ", DISTRIBUTIONS_DIRECTORY, " - ", DATASET)

        # Labels plus mean / (co)variance of the logits for this split
        y_true = torch.load(os.path.join(DISTRIBUTIONS_DIRECTORY, f"y_true_{DATASET}.pt"))
        f_mu = torch.load(os.path.join(DISTRIBUTIONS_DIRECTORY, f"f_mu_{DATASET}.pt"))
        f_var = torch.load(os.path.join(DISTRIBUTIONS_DIRECTORY, f"f_var_{DATASET}.pt"))

        confs, preds, variances, y_pred = calculate_confs_preds_variances_ypreds(f_mu, f_var, y_true)

        title = f'{title_name} on {DatasetTranslationDict[DATASET]}'
        make_reliability_diagram_with_conf_uncertainty(y_pred, variances, y_true, title_name=title)
        plt.savefig(os.path.join(reliabilityDiagramsSavedir, f'{title_name}_{DATASET}.pdf'))
        plt.show()



# SkinLesions

In [None]:
# Directories with the stored predictive distributions; title_names holds the
# matching method label, in the same order.
distribution_directories = [
    './results/predictive_distributions/SkinLesions/',
    './results/predictive_distributions/SkinLesions_ts/',
    './results/predictive_distributions/SkinLesions_scaling/',
    './results/predictive_distributions/SkinLesions_ts_and_scaling_fitted/',
]
title_names = ["LLLA", "LLLA+WITS", "LLLA+CVS", "LLLA+WITS+CVS"]

# Human-readable dataset names used in the figure titles
DatasetTranslationDict = {
    'camelyon17-id': 'WILDS-Camelyon17 (ID)',
    'camelyon17-ood': 'WILDS-Camelyon17 (OOD)',
    'amazon-id': 'WILDS-Amazon (ID)',
    'amazon-ood': 'WILDS-Amazon (OOD)',
    'SkinLesions-id': 'SkinLesions (ID)',
    'SkinLesions-ood': 'SkinLesions (OOD)',
}

# One reliability diagram per (method, dataset split), saved as PDF and shown
for DISTRIBUTIONS_DIRECTORY, title_name in zip(distribution_directories, title_names):
    for DATASET in ['SkinLesions-id', 'SkinLesions-ood']:
        print("### ", DISTRIBUTIONS_DIRECTORY, " - ", DATASET)

        # Labels plus mean / (co)variance of the logits for this split
        y_true = torch.load(os.path.join(DISTRIBUTIONS_DIRECTORY, f"y_true_{DATASET}.pt"))
        f_mu = torch.load(os.path.join(DISTRIBUTIONS_DIRECTORY, f"f_mu_{DATASET}.pt"))
        f_var = torch.load(os.path.join(DISTRIBUTIONS_DIRECTORY, f"f_var_{DATASET}.pt"))

        confs, preds, variances, y_pred = calculate_confs_preds_variances_ypreds(f_mu, f_var, y_true)

        title = f'{title_name} on {DatasetTranslationDict[DATASET]}'
        make_reliability_diagram_with_conf_uncertainty(y_pred, variances, y_true, title_name=title)
        plt.savefig(os.path.join(reliabilityDiagramsSavedir, f'{title_name}_{DATASET}.pdf'))
        plt.show()




# Plot confidence-accuracy gap by size of uncertainty

### Nice Plots PR

In [None]:
import os
print(os.getcwd())
import argparse
import yaml

import torch
import pycalib
from laplace import Laplace

import utils.data_utils as du
import utils.wilds_utils as wu
import utils.utils as util
from utils.test import test
from marglik_training.train_marglik import get_backend

# import warnings
# warnings.filterwarnings('ignore')

from argparse import Namespace

from tqdm import tqdm

import matplotlib.pyplot as plt

from copy import deepcopy

from random import randint

import numpy as np

In [None]:
# import matplotlib
# matplotlib.rcParams["figure.dpi"] = 300

In [None]:
def batch_cov(points):
    """Compute the unbiased sample covariance matrix of every batch element.

    Parameters
    ----------
    points : torch.Tensor
        `(B, N, D)`: for each of the `B` batch elements, `N` observations of
        a `D`-dimensional variable.

    Returns
    -------
    torch.Tensor
        `(B, D, D)` covariance matrices, normalised by `N - 1`.
    """
    B, N, D = points.size()
    centered = points - points.mean(dim=1, keepdim=True)
    # A single batched matmul contracts over the N observations directly, so
    # the (B, N, D, D) tensor of per-observation outer products is never built.
    bcov = torch.bmm(centered.transpose(1, 2), centered) / (N - 1)  # Unbiased estimate
    return bcov  # (B, D, D)



In [None]:
def plot_conf_acc_gap_by_uncertainty_to_ax(ax, y_pred, variances, labels, n_bins=10, marker='o'):
    """Scatter the confidence-accuracy gap of every confidence bin against its uncertainty.

    Samples are grouped into `n_bins` equal-width bins by their confidence
    (largest predicted probability). Every bin contributes one point:
    x is two standard deviations of the bin's mean confidence
    (sqrt(mean variance in bin / samples in bin)), y is
    |mean confidence - accuracy| of the bin. Empty bins have a NaN x-value.

    Parameters
    ----------
    ax : matplotlib axes
        axes the points are scattered onto
    y_pred : torch.Tensor
        `(n, num_classes)` predicted class probabilities (after softmax)
    variances : torch.Tensor
        `(n,)` variance of the predicted class' probability for every sample
    labels : torch.Tensor
        `(n,)` true labels
    n_bins : int
        number of equal-width confidence bins on [0, 1]
    marker : str
        marker style passed to `ax.scatter`
    """
    confidences, predictions = y_pred.max(1)
    accuracies = torch.eq(predictions, labels)

    # Equal-width bins [lower, upper) over the confidence axis
    bins = torch.linspace(0, 1, n_bins + 1)
    bin_indices = [confidences.ge(bin_lower) * confidences.lt(bin_upper) for bin_lower, bin_upper in zip(bins[:-1], bins[1:])]
    # The last bin gets no upper bound: with a half-open last bin, samples
    # whose confidence is exactly 1.0 (saturated softmax) would belong to no
    # bin and silently disappear from the plot.
    bin_indices[-1] = confidences.ge(bins[-2])

    # Variance of a bin's mean confidence: mean per-sample variance / count.
    # Empty bins have a NaN mean; their count is clamped to 1 to avoid a
    # division by zero.
    bin_mean_variances = np.array([float(variances[bin_index].mean()) for bin_index in bin_indices])
    bin_counts = np.array([int(bin_index.sum()) for bin_index in bin_indices])
    bin_counts = np.maximum(bin_counts, 1)
    bin_mean_variances = bin_mean_variances * (1 / bin_counts)
    bin_mean_2stds = np.sqrt(bin_mean_variances) * 2

    # Per-bin accuracy and mean confidence; empty bins (NaN) count as 0
    bin_corrects = np.nan_to_num(np.array([float(torch.mean(accuracies[bin_index].float())) for bin_index in bin_indices]))
    bin_scores = np.nan_to_num(np.array([float(torch.mean(confidences[bin_index].float())) for bin_index in bin_indices]))
    gap = np.abs(bin_scores - bin_corrects)

    # Single colour for all points (no colour-coding by bin count)
    colors = bundles.rgb.tue_darkblue

    ax.scatter(bin_mean_2stds, gap, c = colors, marker=marker)

In [None]:
def plot_conf_acc_gap_by_uncertainty_plot(distribution_directories, distribution_names, dataset_names, figuretitle="", justPlotTop=False, addXLabel=False):
    """Build the grid of confidence-accuracy-gap vs. uncertainty scatter plots.

    Rows correspond to the entries of `distribution_directories`, columns to
    the entries of `dataset_names`. The grid is 2x2, or 1x2 when
    `justPlotTop` is set, so at most two directories and two datasets fit.
    The figure is created but neither shown nor saved.

    Parameters
    ----------
    distribution_directories : list of str
        directories containing `y_true_<dataset>.pt`, `f_mu_<dataset>.pt`
        and `f_var_<dataset>.pt`
    distribution_names : list of str
        method labels, one per directory, used in the y-axis labels
    dataset_names : list of str
        dataset identifiers; must be keys of the title lookup table below
    figuretitle : str
        figure-level title
    justPlotTop : bool
        if True only the first directory is plotted, in a single row
    addXLabel : bool
        if True every plotted row gets an x-axis label (otherwise only the
        second row does)
    """
    pretty_dataset_names = {'camelyon17-id': 'WILDS-Camelyon17 (ID)', 'camelyon17-ood': 'WILDS-Camelyon17 (OOD)', 'amazon-id': 'WILDS-Amazon (ID)', 'amazon-ood': 'WILDS-Amazon (OOD)', 'SkinLesions-id': 'SkinLesions (ID)', 'SkinLesions-ood': 'SkinLesions (OOD)'}
    gap_label = r'$\vert$Conf - Acc$\vert$'

    if justPlotTop:
        figsize = (4, 2.2) if addXLabel else (4, 2)
        fig, axs = plt.subplots(1, 2)
        axs = [axs]  # keep the axs[row][col] indexing used below
    else:
        figsize = (4, 3.8)
        fig, axs = plt.subplots(2, 2)
    fig.set_size_inches(figsize)

    for row, (directory, distribution_name) in enumerate(zip(distribution_directories, distribution_names)):
        if justPlotTop and row >= 1:
            # only the first method is drawn in the single-row layout
            continue

        for col, dataset in enumerate(dataset_names):
            print("### ", directory, " - ", dataset)
            y_true = torch.load(os.path.join(directory, "y_true_" + dataset + ".pt"))
            f_mu = torch.load(os.path.join(directory, "f_mu_" + dataset + ".pt"))
            f_var = torch.load(os.path.join(directory, "f_var_" + dataset + ".pt"))

            _, _, variances, y_pred = calculate_confs_preds_variances_ypreds(f_mu, f_var, y_true)

            ax = axs[row][col]
            plot_conf_acc_gap_by_uncertainty_to_ax(ax, y_pred, variances, y_true, n_bins=100)

            # Dataset names as column headers on the first row
            if row == 0:
                ax.set_title(pretty_dataset_names[dataset])
            # Left column carries the y label (prefixed by the method name in
            # the two-row layout)
            if col == 0:
                if justPlotTop:
                    ax.set_ylabel(gap_label)
                else:
                    ax.set_ylabel(distribution_name + "\n" + gap_label)
            if row == 1 or addXLabel:
                ax.set_xlabel(r"$2 \sigma$ of confidence mean")

    fig.suptitle(figuretitle)


In [None]:
# Output directory for the gap-vs-uncertainty PDFs. exist_ok=True makes
# re-running the cell safe without a separate (race-prone) existence check.
savedir_confAccGapByUncertainty = "./results/img/Results/ConfAccGapByUncertainty"
os.makedirs(savedir_confAccGapByUncertainty, exist_ok=True)

In [None]:
# Stored predictive distributions to compare, with matching method labels
distribution_directories = [
    './results/predictive_distributions/camelyon17/',  # TODO or use:
    # './results/predictive_distributions/camelyon17_ts/', #TODO or use this one?
    './results/predictive_distributions/camelyon17_ts_and_scaling_fitted',
]
distribution_names = ["LLLA", "LLLA+WITS+CVS"]
dataset_names = ['camelyon17-id', 'camelyon17-ood']

# Full grid: one row per method, one column per dataset split.
# (Optionally pass figuretitle="Conf-Acc Gap by Uncertainty for DenseNet121 on Camelyon17".)
plot_conf_acc_gap_by_uncertainty_plot(
    distribution_directories=distribution_directories,
    distribution_names=distribution_names,
    dataset_names=dataset_names,
)
plt.savefig(os.path.join(savedir_confAccGapByUncertainty, "camelyon17.pdf"))
plt.show()

# Compact variant: only the row of the first method
plot_conf_acc_gap_by_uncertainty_plot(
    distribution_directories=distribution_directories,
    distribution_names=distribution_names,
    dataset_names=dataset_names,
    justPlotTop=True,
)
plt.savefig(os.path.join(savedir_confAccGapByUncertainty, "camelyon17_TOPHALF.pdf"))
plt.show()

In [None]:
# distribution_directories = [#'./results/predictive_distributions/camelyon17/', # TODO or use:
#                             './results/predictive_distributions/camelyon17_ts/', #TODO or use this one?
#                             './results/predictive_distributions/camelyon17_ts_and_scaling_fitted']
# distribution_names = ["LLLA + TS(WI)", "LLLA + TS(WI) + Cov-scaling"]
# dataset_names = ['camelyon17-id', 'camelyon17-ood']

# plot_conf_acc_gap_by_uncertainty_plot(distribution_directories, distribution_names, dataset_names, figuretitle="Conf-Acc Gap by Uncertainty for DenseNet121 on Camelyon17")
# plt.show()



In [None]:
# Stored predictive distributions of the "model6" runs, with matching method labels
distribution_directories = [
    './results/predictive_distributions/camelyon17_model6/',  # TODO or use:
    # './results/predictive_distributions/camelyon17_ts/', #TODO or use this one?
    './results/predictive_distributions/camelyon17_ts_and_scaling_fitted_model6',
]
distribution_names = ["LLLA", "LLLA+WITS+CVS"]
dataset_names = ['camelyon17-id', 'camelyon17-ood']

# Full grid (with figure title): one row per method, one column per dataset split
plot_conf_acc_gap_by_uncertainty_plot(
    distribution_directories=distribution_directories,
    distribution_names=distribution_names,
    dataset_names=dataset_names,
    figuretitle="Conf-Acc Gap by Uncertainty for DenseNet121 on Camelyon17",
)
plt.savefig(os.path.join(savedir_confAccGapByUncertainty, "camelyon17_model6.pdf"))
plt.show()

# Compact variant (no figure title): only the row of the first method
plot_conf_acc_gap_by_uncertainty_plot(
    distribution_directories=distribution_directories,
    distribution_names=distribution_names,
    dataset_names=dataset_names,
    justPlotTop=True,
)
plt.savefig(os.path.join(savedir_confAccGapByUncertainty, "camelyon17_model6_TOPHALF.pdf"))
plt.show()


# Resnet 50

In [None]:
# distribution_directories = ['./results/predictive_distributions/camelyon17_resnet50/', # TODO or use:
#                             # './results/predictive_distributions/camelyon17_ts/', #TODO or use this one?
#                             './results/predictive_distributions/camelyon17_resnet50_ts_and_scaling_fitted']
# distribution_names = ["LLLA", "LLLA+WITS+CVS"]
# dataset_names = ['camelyon17-id', 'camelyon17-ood']

# # plot_conf_acc_gap_by_uncertainty_plot(distribution_directories, distribution_names, dataset_names, justPlotTop=True, figuretitle="Conf-Acc Gap by Uncertainty for ResNet50 on Camelyon17")
# plot_conf_acc_gap_by_uncertainty_plot(distribution_directories, distribution_names, dataset_names, justPlotTop=True)
# plt.savefig(os.path.join(savedir_confAccGapByUncertainty, "camelyon17_ResNet.pdf"))
# plt.show()


# Amazon

In [None]:
# Stored predictive distributions to compare, with matching method labels
distribution_directories = [
    './results/predictive_distributions/amazon/',
    # './results/predictive_distributions/amazon_scaling/',
    # './results/predictive_distributions/amazon_ts/',
    './results/predictive_distributions/amazon_ts_and_scaling_fitted/',
]
distribution_names = ["LLLA", "LLLA+WITS+CVS"]
dataset_names = ['amazon-id', 'amazon-ood']

# The full two-row figure is currently disabled:
# plot_conf_acc_gap_by_uncertainty_plot(distribution_directories, distribution_names, dataset_names, figuretitle="Conf-Acc Gap by Uncertainty for DistilBERT on Amazon")
# plt.savefig(os.path.join(savedir_confAccGapByUncertainty, "amazon.pdf"))
# plt.show()

# Compact variant: only the row of the first method
plot_conf_acc_gap_by_uncertainty_plot(
    distribution_directories=distribution_directories,
    distribution_names=distribution_names,
    dataset_names=dataset_names,
    justPlotTop=True,
)
plt.savefig(os.path.join(savedir_confAccGapByUncertainty, "amazon_TOPHALF.pdf"))
plt.show()


In [None]:
# distribution_directories = [# './results/predictive_distributions/amazon/',
#                             # './results/predictive_distributions/amazon_scaling/',
#                             './results/predictive_distributions/amazon_ts/',
#                             './results/predictive_distributions/amazon_ts_and_scaling_fitted/',
# ]
# distribution_names = ["LLLA+WITS", "LLLA+WITS+CVS"]
# dataset_names = ['amazon-id', 'amazon-ood']

# plot_conf_acc_gap_by_uncertainty_plot(distribution_directories, distribution_names, dataset_names)
# plt.show()


# SkinLesions

In [None]:
# Stored predictive distributions to compare, with matching method labels
distribution_directories = [
    './results/predictive_distributions/SkinLesions/',  # TODO or use:
    # './results/predictive_distributions/SkinLesions_ts/', #TODO or use this one?
    './results/predictive_distributions/SkinLesions_ts_and_scaling_fitted',
]
distribution_names = ["LLLA", "LLLA+WITS+CVS"]
dataset_names = ['SkinLesions-id', 'SkinLesions-ood']

# The full two-row figure is currently disabled:
# plot_conf_acc_gap_by_uncertainty_plot(distribution_directories, distribution_names, dataset_names, figuretitle="Conf-Acc Gap by Uncertainty for ResNet50 on SkinLesions")
# plt.savefig(os.path.join(savedir_confAccGapByUncertainty, "skinlesions.pdf"))
# plt.show()

# Compact variant: only the row of the first method, with x-axis labels
plot_conf_acc_gap_by_uncertainty_plot(
    distribution_directories=distribution_directories,
    distribution_names=distribution_names,
    dataset_names=dataset_names,
    justPlotTop=True,
    addXLabel=True,
)
plt.savefig(os.path.join(savedir_confAccGapByUncertainty, "skinlesions_TOPHALF.pdf"))
plt.show()

# Toy Datasets

In [None]:
def _toy_dataset_ids(directory):
    """Return the dataset ids that have a ``y_true_<id>.pt`` file in *directory*.

    Ids are returned exactly as they appear in the file names.  If every id
    parses as an integer they are ordered numerically; otherwise they are
    ordered lexicographically, so the result never depends on the arbitrary
    order in which ``os.listdir`` reports the files.
    """
    prefix, suffix = "y_true_", ".pt"
    ids = [
        name[len(prefix):-len(suffix)]
        for name in os.listdir(directory)
        if name.startswith(prefix) and name.endswith(suffix)
    ]
    try:
        # key=int sorts numerically while keeping the original spelling, so
        # the id still matches the file name when the tensors are loaded.
        return sorted(ids, key=int)
    except ValueError:
        return sorted(ids)


def plot_ToyDatasets_conf_acc_gap_by_uncertainty_plot(distribution_directories, distribution_names, LeftToRight=True, figuretitle=''):
    """Draw a 3 x 6 grid of conf-acc-gap-by-uncertainty panels.

    Every entry of ``distribution_directories`` fills one row of the grid and
    the first six dataset ids found in that directory fill its columns.  For
    each panel the files ``y_true_<id>.pt``, ``f_mu_<id>.pt`` and
    ``f_var_<id>.pt`` are loaded with ``torch.load``, handed to
    ``calculate_confs_preds_variances_ypreds`` and drawn with
    ``plot_conf_acc_gap_by_uncertainty_to_ax`` (100 bins, ``'.'`` markers).

    Panel titles are tied to the row index: rows 0 and 1 are titled
    "Rotation: <id>" and row 2 "Corruption: <id>", which matches the
    R-MNIST / R-FMNIST / CIFAR-10-C ordering used by the callers in this
    notebook.  Id "0" is displayed as "0 (ID)".

    The figure is neither saved nor shown; callers do that through ``plt``.

    Args:
        distribution_directories: Directories holding the saved tensors, one
            per grid row.  The grid has three rows, so a fourth directory
            raises ``IndexError``.
        distribution_names: Row labels, paired with the directories via
            ``zip`` (surplus entries of the longer sequence are ignored).
        LeftToRight: Layout switch.  Only ``True`` is supported.
        figuretitle: Text passed to ``fig.suptitle``.

    Raises:
        NotImplementedError: If ``LeftToRight`` is false.
    """
    if not LeftToRight:
        # The top-to-bottom (6 x 3) layout for page-sized figures was never
        # finished; fail before any figure is created.
        raise NotImplementedError("Top-to-bottom layout (LeftToRight=False) is not implemented.")

    n_rows, n_cols = 3, 6
    fig, axs = plt.subplots(n_rows, n_cols)
    fig.set_size_inches((8, 4.4))

    for row, (directory, distribution_name) in enumerate(zip(distribution_directories, distribution_names)):
        # Only the first n_cols ID/OOD conditions fit into the grid.
        for col, dataset in enumerate(_toy_dataset_ids(directory)[:n_cols]):
            print("### ", directory, " - ", dataset)
            y_true = torch.load(os.path.join(directory, "y_true_" + dataset + ".pt"))
            f_mu = torch.load(os.path.join(directory, "f_mu_" + dataset + ".pt"))
            f_var = torch.load(os.path.join(directory, "f_var_" + dataset + ".pt"))

            # The first two return values are not needed for this plot.
            _, _, variances, y_pred = calculate_confs_preds_variances_ypreds(f_mu, f_var, y_true)

            ax = axs[row][col]
            plot_conf_acc_gap_by_uncertainty_to_ax(ax, y_pred, variances, y_true, n_bins=100, marker='.')

            dataset_label = "0 (ID)" if dataset == "0" else dataset
            if row in (0, 1):
                ax.set_title(f'Rotation: {dataset_label}')
            elif row == 2:
                ax.set_title(f'Corruption: {dataset_label}')
            if col == 0:
                # Leftmost panel of each row carries the row (distribution) name.
                ax.set_ylabel(distribution_name + "\n" + r'$\vert$Conf - Acc$\vert$')
            if row == 2:
                # Bottom row carries the shared x-axis label.
                ax.set_xlabel(r"$2 \sigma$ of conf mean")

    fig.suptitle(figuretitle)


In [None]:
# # TODO remove testing
# distribution_directories = ['./results/predictive_distributions/R-MNIST',
#                             # './results/predictive_distributions/R-MNIST_scaling',
#                             # './results/predictive_distributions/R-MNIST_ts',
#                             # './results/predictive_distributions/R-MNIST_ts_scaling',
#                             './results/predictive_distributions/R-FMNIST',
#                             # './results/predictive_distributions/R-FMNIST_scaling',
#                             # './results/predictive_distributions/R-FMNIST_ts',
#                             # './results/predictive_distributions/R-FMNIST_ts_scaling',
#                             './results/predictive_distributions/CIFAR-10-C',
#                             # './results/predictive_distributions/CIFAR-10-C_scaling',
#                             # './results/predictive_distributions/CIFAR-10-C_ts',
#                             # './results/predictive_distributions/CIFAR-10-C_ts_scaling',
# ]

# distribution_names = ["R-MNIST", "R-FMNIST", "CIFAR-10-C"]


# plot_ToyDatasets_conf_acc_gap_by_uncertainty_plot(distribution_directories, distribution_names, LeftToRight=True, figuretitle='Conf-Acc Gap by Uncertainty\nfor R-MNIST, R-FMNIST and CIFAR-10-C')
# plt.savefig(os.path.join(savedir_confAccGapByUncertainty, "toy_datasets.pdf"))
# plt.show()


In [None]:
# Unsuffixed predictive-distribution directories, one per toy benchmark.
# The `_scaling`, `_ts` and `_ts_scaling` variants of each directory are
# plotted by separate cells.
distribution_directories = [
    './results/predictive_distributions/R-MNIST',
    './results/predictive_distributions/R-FMNIST',
    './results/predictive_distributions/CIFAR-10-C',
]
distribution_names = ["R-MNIST", "R-FMNIST", "CIFAR-10-C"]

# No figure title is set; to add one pass
# figuretitle='Conf-Acc Gap by Uncertainty for R-MNIST, R-FMNIST and CIFAR-10-C'.
plot_ToyDatasets_conf_acc_gap_by_uncertainty_plot(
    distribution_directories,
    distribution_names,
)
plt.savefig(os.path.join(savedir_confAccGapByUncertainty, "toy_datasets.pdf"))
plt.show()


In [None]:
# `_scaling` variant of each toy benchmark directory (saved as the CVS figure).
# The unsuffixed, `_ts` and `_ts_scaling` variants are plotted by separate
# cells.
distribution_directories = [
    './results/predictive_distributions/R-MNIST_scaling',
    './results/predictive_distributions/R-FMNIST_scaling',
    './results/predictive_distributions/CIFAR-10-C_scaling',
]
distribution_names = ["R-MNIST", "R-FMNIST", "CIFAR-10-C"]

# No figure title is set; to add one pass
# figuretitle='Conf-Acc Gap by Uncertainty for R-MNIST, R-FMNIST and CIFAR-10-C using CVS'.
plot_ToyDatasets_conf_acc_gap_by_uncertainty_plot(
    distribution_directories,
    distribution_names,
)
plt.savefig(os.path.join(savedir_confAccGapByUncertainty, "toy_datasets_CVS.pdf"))
plt.show()


In [None]:
# `_ts` variant of each toy benchmark directory (saved as the WITS figure).
# The unsuffixed, `_scaling` and `_ts_scaling` variants are plotted by
# separate cells.
distribution_directories = [
    './results/predictive_distributions/R-MNIST_ts',
    './results/predictive_distributions/R-FMNIST_ts',
    './results/predictive_distributions/CIFAR-10-C_ts',
]
distribution_names = ["R-MNIST", "R-FMNIST", "CIFAR-10-C"]

# No figure title is set; to add one pass
# figuretitle='Conf-Acc Gap by Uncertainty for R-MNIST, R-FMNIST and CIFAR-10-C using WITS'.
plot_ToyDatasets_conf_acc_gap_by_uncertainty_plot(
    distribution_directories,
    distribution_names,
)
plt.savefig(os.path.join(savedir_confAccGapByUncertainty, "toy_datasets_WITS.pdf"))
plt.show()



In [None]:
# `_ts_scaling` variant of each toy benchmark directory (saved as the
# WITS+CVS figure).  The unsuffixed, `_scaling` and `_ts` variants are
# plotted by separate cells.
distribution_directories = [
    './results/predictive_distributions/R-MNIST_ts_scaling',
    './results/predictive_distributions/R-FMNIST_ts_scaling',
    './results/predictive_distributions/CIFAR-10-C_ts_scaling',
]
distribution_names = ["R-MNIST", "R-FMNIST", "CIFAR-10-C"]

# No figure title is set; to add one pass
# figuretitle='Conf-Acc Gap by Uncertainty for R-MNIST, R-FMNIST and CIFAR-10-C using WITS+CVS'.
plot_ToyDatasets_conf_acc_gap_by_uncertainty_plot(
    distribution_directories,
    distribution_names,
)
plt.savefig(os.path.join(savedir_confAccGapByUncertainty, "toy_datasets_WITS_CVS.pdf"))
plt.show()

