In [None]:
import os
# Show the working directory: the relative paths used below ('./data', './results/...') resolve against it.
os.getcwd()

In [None]:
import argparse
import yaml

import torch
import pycalib
from laplace import Laplace

import utils.data_utils as du
import utils.wilds_utils as wu
import utils.utils as util
from utils.test import test
from marglik_training.train_marglik import get_backend

# import warnings
# warnings.filterwarnings('ignore')

from argparse import Namespace

from tqdm import tqdm

import matplotlib.pyplot as plt

from copy import deepcopy

from random import randint

import numpy as np

In [None]:
from tueplots import bundles



# Inspired by bundles.neurips2023(), but adapting font sizes for pt12 standard.
# The rcParams are grouped by purpose and merged, in this order, into `settings_dict`.

# Text rendering through LaTeX, with Times (ptm) as roman and Helvetica (phv) as sans-serif font.
_latex_text_settings = {
    'text.usetex': True,
    'font.family': 'serif',
    'text.latex.preamble': '\\renewcommand{\\rmdefault}{ptm}\\renewcommand{\\sfdefault}{phv}',
}

# Figure geometry (width 5.5 in, height = width / golden ratio) and export behaviour.
_figure_settings = {
    'figure.figsize': (5.5, 3.399186938124422),
    'figure.constrained_layout.use': True,
    'figure.autolayout': False,
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.015,
}

# Font sizes (pt) of the individual text elements.
_font_size_settings = {
    'font.size': 10,
    'axes.labelsize': 10,
    'legend.fontsize': 8,
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
    'axes.titlesize': 10,
}

settings_dict = {**_latex_text_settings, **_figure_settings, **_font_size_settings,
                 'figure.dpi': 300}

plt.rcParams.update(settings_dict)





# Can use colors from bundles.rgb.
#     tue_blue
#     tue_brown
#     tue_dark
#     tue_darkblue
#     tue_darkgreen
#     tue_gold
#     tue_gray
#     tue_green
#     tue_lightblue
#     tue_lightgold
#     tue_lightgreen
#     tue_lightorange
#     tue_mauve
#     tue_ocre
#     tue_orange
#     tue_red
#     tue_violet

In [None]:
from torchvision import transforms

def invImageNetNorm(x, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Invert a channel-wise ``transforms.Normalize(mean, std)``.

    With the default arguments this undoes the ImageNet normalisation
    ``transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])``
    by computing ``x * std + mean`` per channel.

    Parameters
    ----------
    x : torch.Tensor
        Normalised image(s) with the channel dimension at position -3,
        i.e. ``(C, H, W)`` or ``(..., C, H, W)``.
    mean, std : sequence of float
        Per-channel statistics used by the forward normalisation.

    Returns
    -------
    torch.Tensor
        Tensor of the same shape as ``x``, back on the un-normalised scale.
    """
    # Shape the statistics as (C, 1, 1) so they broadcast over H, W and any leading batch dims.
    mean_t = torch.as_tensor(mean, dtype=x.dtype, device=x.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=x.dtype, device=x.device).view(-1, 1, 1)
    return x * std_t + mean_t

In [None]:
def batch_cov(points):
    """Unbiased sample covariance for a batch of point sets.

    Parameters
    ----------
    points : torch.Tensor
        ``(B, N, D)`` tensor: ``B`` independent sets of ``N`` samples in ``D`` dimensions.

    Returns
    -------
    torch.Tensor
        ``(B, D, D)`` covariance matrices, normalised by ``N - 1``.
    """
    _, n_points, _ = points.size()  # also validates that `points` is 3-dimensional
    centered = points - points.mean(dim=1, keepdim=True)
    # (B, D, N) @ (B, N, D) sums the per-sample outer products without materialising
    # the (B, N, D, D) intermediate, which is large for many samples.
    return centered.transpose(1, 2) @ centered / (n_points - 1)  # Unbiased estimate


In [None]:
def normal_samples(mean, var, n_samples, generator=None):
    """Draw samples from a batch of Normal distributions.

    Each row of `mean` is the mean of one distribution. `var` holds either the
    per-dimension variances (diagonal covariance) or full covariance matrices.

    Parameters
    ----------
    mean : torch.Tensor
        `(batch_size, output_dim)`
    var : torch.Tensor
        `(batch_size, output_dim)` for a diagonal covariance or
        `(batch_size, output_dim, output_dim)` for a full covariance
    n_samples : int
        number of samples drawn per distribution
    generator : torch.Generator, optional
        random number generator for the standard-normal draws

    Returns
    -------
    torch.Tensor
        `(n_samples, batch_size, output_dim)`

    Raises
    ------
    ValueError
        if the shapes of `mean` and `var` do not fit together
    """
    assert mean.ndim == 2, 'Invalid input shape of mean, should be 2-dimensional.'
    _, dim = mean.shape
    # A single (dim, n_samples) block of standard-normal draws; the same draws are
    # broadcast to (and thus reused by) every row of the batch.
    eps = torch.randn((dim, n_samples), device=mean.device,
                      dtype=mean.dtype, generator=generator)

    diagonal = var.shape == mean.shape
    full = (not diagonal
            and var.shape[:2] == mean.shape
            and var.shape[-1] == dim)
    if not (diagonal or full):
        raise ValueError('Invalid input shapes.')

    if diagonal:
        # (B, D, 1) * (1, D, S): scale each dimension by its standard deviation
        deviations = var.sqrt()[..., None] * eps[None]
    else:
        # Cholesky factor L with var = L L^T maps eps onto the target covariance
        deviations = torch.matmul(torch.linalg.cholesky(var), eps[None])
    return (mean[..., None] + deviations).permute(2, 0, 1)



In [None]:
def get_appropriate_testloader(dataset):
    """Build the test loader that matches a saved predictive distribution.

    Parameters
    ----------
    dataset : str
        One of 'camelyon17-id', 'camelyon17-ood', 'SkinLesions-id', 'SkinLesions-ood'.

    Returns
    -------
    The test loader for `dataset`, as returned by the wilds / data utilities.

    Raises
    ------
    ValueError
        If `dataset` is not one of the supported names.
    """
    if dataset == 'camelyon17-id':
        # Only the in-distribution test split is needed; train/val loaders are discarded.
        _, _, test_loader = wu.get_wilds_loaders(
            'camelyon17', './data', 1.0, 1, download=False, use_ood_val_set=False)
    elif dataset == 'camelyon17-ood':
        test_loader = wu.get_wilds_ood_test_loader(
            'camelyon17', './data', 1.0)
    elif dataset == 'SkinLesions-id':
        _, _, test_loader = du.get_ham10000_loaders(
            './data', batch_size=16, train_batch_size=16, num_workers=4, image_size=512)
    elif dataset == 'SkinLesions-ood':
        test_loader = du.get_SkinLesions_ood_loader(
            None, data_path='./data', batch_size=16, num_workers=4, image_size=512)
    else:
        raise ValueError(
            f"Unknown dataset {dataset!r}; expected one of 'camelyon17-id', "
            "'camelyon17-ood', 'SkinLesions-id', 'SkinLesions-ood'.")
    return test_loader


In [None]:
def check_testloader_order(test_loader, y_true):
    """Assert that the labels yielded by `test_loader` match `y_true` in order.

    This guards against the saved predictions (aligned with `y_true`) being paired
    with images in a different order. Batches may be `(x, y)` tuples or carry extra
    elements such as `(x, y, metadata)`; the label is always taken from position 1.

    Raises
    ------
    AssertionError
        If the number of labels or any individual label differs.
    """
    y_true_testloader = torch.cat([batch[1] for batch in test_loader])
    assert y_true_testloader.shape == y_true.shape, (
        f'test_loader yields labels of shape {tuple(y_true_testloader.shape)}, '
        f'expected {tuple(y_true.shape)}')
    assert torch.all(y_true_testloader == y_true), 'test_loader labels are not in the order of y_true'

In [None]:
def calculate_confs_preds_variances(f_mu, f_var, y_true, n_samples = 10000, generator = None, batchsize = 128):
    """Monte-Carlo estimate of confidence, prediction and prediction variance per input.

    For every input, `n_samples` logit vectors are drawn from N(f_mu, f_var) via
    `normal_samples` and passed through a softmax. The mean of the sampled
    probabilities is the predictive distribution.

    Parameters
    ----------
    f_mu : torch.Tensor
        `(N, C)` logit means.
    f_var : torch.Tensor
        `(N, C)` diagonal or `(N, C, C)` full logit covariances.
    y_true : torch.Tensor
        Labels; only their count `N` is used here.
    n_samples : int
        Monte-Carlo samples per input.
    generator : torch.Generator, optional
        Forwarded to `normal_samples` for reproducible sampling.
    batchsize : int
        Number of inputs processed at once.

    Returns
    -------
    confs : torch.Tensor
        `(N,)` largest mean class probability.
    preds : torch.Tensor
        `(N,)` index of that class.
    variances : torch.Tensor
        `(N,)` unbiased sample variance of the predicted class's probability,
        i.e. the `[pred, pred]` entry of the sampled probability covariance.
        All outputs are on the device of `f_mu`.
    """
    n_points = y_true.shape[0]
    confs_list, preds_list, variances_list = [], [], []

    for start in tqdm(range(0, n_points, batchsize)):
        stop = min(start + batchsize, n_points)
        f_samples = normal_samples(f_mu[start:stop], f_var[start:stop], n_samples, generator)
        y_prob = torch.softmax(f_samples, dim=-1)  # (n_samples, batch, C)

        confs, preds = torch.max(y_prob.mean(dim=0), dim=1)

        # Only the diagonal [pred, pred] entry of each probability covariance is needed;
        # that is the unbiased per-class variance over the sample dimension.
        class_variances = y_prob.var(dim=0)  # (batch, C)
        variances = class_variances.gather(1, preds.unsqueeze(1)).squeeze(1)

        confs_list.append(confs)
        preds_list.append(preds)
        variances_list.append(variances)

    if not confs_list:
        # No inputs: return empty results instead of failing inside torch.cat.
        empty = f_mu.new_empty(0)
        return empty, torch.empty(0, dtype=torch.long, device=f_mu.device), empty.clone()

    return torch.cat(confs_list), torch.cat(preds_list), torch.cat(variances_list)

In [None]:
def plot_conf_variance(confs, variances, title_string = '', alpha = 0.05):
    """Scatter-plot per-input variance against confidence and show the figure.

    Parameters
    ----------
    confs, variances : torch.Tensor
        Per-input values; converted with `.numpy()`, so they must live on the CPU.
    title_string : str
        Axes title.
    alpha : float
        Marker transparency.
    """
    _, ax = plt.subplots()
    ax.scatter(confs.numpy(), variances.numpy(), alpha=alpha)
    ax.set(xlabel="confidence", ylabel="variances", title=title_string)
    plt.show()

In [None]:
# # DISTRIBUTIONS_DIRECTORY = './results/predictive_distributions/camelyon17_model6/'
# # DISTRIBUTIONS_DIRECTORY = './results/predictive_distributions/camelyon17_resnet50/'
# DISTRIBUTIONS_DIRECTORY = './results/predictive_distributions/camelyon17_wrn50/'

# # DISTRIBUTIONS_DIRECTORY = "/mnt/j/Results_Predictive_Distributions/camelyon17_ts_and_scaling_fitted"


# # DATASET = 'camelyon17-id' # 'camelyon17-ood'
# DATASET = 'camelyon17-ood'

# # images = torch.load(os.path.join(DISTRIBUTIONS_DIRECTORY, "x_" + DATASET + ".pt"))
# y_true = torch.load(os.path.join(DISTRIBUTIONS_DIRECTORY, "y_true_" + DATASET + ".pt"))
# # y_prob = torch.load(os.path.join(DISTRIBUTIONS_DIRECTORY, "y_prob_" + DATASET + ".pt"))

# f_mu = torch.load(os.path.join(DISTRIBUTIONS_DIRECTORY, "f_mu_" + DATASET + ".pt"))
# f_var = torch.load(os.path.join(DISTRIBUTIONS_DIRECTORY, "f_var_" + DATASET + ".pt"))


In [None]:
# test_loader = get_appropriate_testloader(DATASET)


In [None]:
# check_testloader_order(test_loader, y_true)


In [None]:
# confs, preds, variances = calculate_confs_preds_variances(f_mu, f_var, y_true)


In [None]:
# plot_conf_variance(confs, variances)

# Same plot, but using the variance of the predicted class's logit (from `f_var`) instead of its probability:

In [None]:
# logit_variances = torch.tensor([c[preds[i], preds[i]] for i, c in enumerate(f_var)])


In [None]:
# fig, ax = plt.subplots()
# ax.scatter(confs.numpy(), logit_variances.numpy(), alpha=0.05)
# ax.set_xlabel("confs")
# ax.set_ylabel("variances")
# plt.show()

# Plot image thumbnails into the confidence-variance scatter plot

In [None]:
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import matplotlib.pyplot as plt
import matplotlib.path as mpath
import matplotlib.patches as mpatches

import matplotlib
# matplotlib.rcParams["figure.dpi"] = 300
# from matplotlib import rc
# rc('text', usetex=True)
matplotlib.rc('text', usetex=True)


In [None]:
from PIL import Image, ImageDraw
def add_cross_to_image(image, color=None, width=1):
    """Draw an 'X' from corner to corner across an image.

    Parameters
    ----------
    image : array-like
        `(H, W, C)` or `(H, W)` image with values in [0, 1]; CPU torch tensors work too.
    color : optional
        Passed as `fill` to `ImageDraw.line` (e.g. a colour name or hex string).
    width : int
        Line width in pixels.

    Returns
    -------
    PIL.Image.Image
        The 8-bit image with both diagonals drawn.
    """
    # Clip before casting: float->uint8 conversion of values outside [0, 255]
    # (e.g. slight over/undershoot after un-normalisation) is undefined and may wrap.
    pixels = np.clip(np.asarray(image, dtype=np.float64), 0.0, 1.0) * 255
    img = Image.fromarray(np.round(pixels).astype(np.uint8))
    draw = ImageDraw.Draw(img)
    draw.line([(0, 0), (img.width, img.height)], width=width, fill=color)
    draw.line([(img.width, 0), (0, img.height)], width=width, fill=color)
    return img

In [None]:
# # NAIVE: just produce an images array as it previously was:
# images = []

# for x, _ in test_loader:
#     images.append(x)
# images = torch.cat(images)
# images.shape


In [None]:
# check that loader.dataset has the same order:

def check_loader_dataset_order(test_loader, y_true):
    """Assert that `test_loader.dataset[i]` carries label `y_true[i]` for every i.

    This confirms that indexing the dataset directly (as `get_images` does) yields
    the images the saved predictions were made for. Dataset items may be `(x, y)`
    or carry extra elements such as `(x, y, metadata)`; the label is taken from
    position 1 and may be a tensor or a plain Python number.

    Raises
    ------
    AssertionError
        If the dataset is shorter than `y_true` or any label differs.
    """
    dataset = test_loader.dataset
    n_labels = len(y_true)
    assert len(dataset) >= n_labels, (
        f'dataset has {len(dataset)} items but y_true has {n_labels} labels')
    if n_labels == 0:
        return

    labels = torch.cat([torch.as_tensor(dataset[i][1]).reshape(1) for i in range(n_labels)])
    assert torch.all(labels == y_true), 'test_loader.dataset labels are not in the order of y_true'

In [None]:
# check_loader_dataset_order(test_loader, y_true)

In [None]:
def get_images(sample_ids, loader):
    """Stack the images at the given dataset indices into one batch.

    Parameters
    ----------
    sample_ids : iterable
        Indices into `loader.dataset`; each element is converted with `int`.
    loader
        Object with a `.dataset` whose items are tuples with the image tensor first.

    Returns
    -------
    torch.Tensor
        `(len(sample_ids), *image_shape)`, in the order of `sample_ids`.
    """
    return torch.stack([loader.dataset[int(idx)][0] for idx in sample_ids])
        

In [None]:
def _rescale_to_unit_interval(values):
    """Affinely map `values` onto [0, 1]; a constant input maps to all zeros instead of 0/0."""
    values = values if values.is_floating_point() else values.to(torch.get_default_dtype())
    shifted = values - values.min()
    span = shifted.max()
    return shifted / span if span > 0 else shifted


def find_image_ids_gridfilling(x, y, IMAGES_SUBSET, aspect_ratio=5.5 / 3.399186938124422):
    """Pick a spatially even subset of points by snapping a regular grid onto them.

    Both coordinates are first rescaled to [0, 1], because the two axes can have very
    different ranges. A grid of roughly `IMAGES_SUBSET` points is then laid over the
    unit square, and for every grid point the nearest data point is selected.

    Parameters
    ----------
    x, y : torch.Tensor
        1-D coordinates of the N data points (e.g. confidences and variances).
    IMAGES_SUBSET : int
        Approximate number of grid points (an upper bound on the subset size
        up to the rounding of rows/columns).
    aspect_ratio : float
        Plot width / height. The grid gets `aspect_ratio` times more columns than
        rows, so grid points are roughly equally spaced on screen. The default
        matches the figure size (5.5, 3.399186938124422) set in the rcParams.

    Returns
    -------
    ids : torch.Tensor
        Sorted unique indices of the selected data points.
    x_limits, y_limits : list of torch.Tensor
        `[min, max]` of the original (un-rescaled) `x` and `y`.
    """
    x_limits = [x.min(), x.max()]
    y_limits = [y.min(), y.max()]

    dtype = torch.promote_types(
        x.dtype if x.is_floating_point() else torch.get_default_dtype(),
        y.dtype if y.is_floating_point() else torch.get_default_dtype())
    x_unit = _rescale_to_unit_interval(x).to(dtype)
    y_unit = _rescale_to_unit_interval(y).to(dtype)

    # n_cols / n_rows ~= aspect_ratio and n_cols * n_rows ~= IMAGES_SUBSET; keep at least one of each.
    n_rows = (IMAGES_SUBSET / aspect_ratio) ** 0.5
    n_cols = n_rows * aspect_ratio
    x_grid = torch.linspace(x_unit.min(), x_unit.max(), max(1, int(n_cols)), dtype=dtype)
    y_grid = torch.linspace(y_unit.min(), y_unit.max(), max(1, int(n_rows)), dtype=dtype)
    xv, yv = torch.meshgrid(x_grid, y_grid, indexing='ij')

    gridpoints = torch.stack([xv.flatten(), yv.flatten()], dim=1)  # (G, 2)
    points = torch.stack([x_unit, y_unit], dim=1)  # (N, 2)

    # (G, N) distance matrix -> index of the nearest data point for each grid point.
    nearest_ids = torch.cdist(gridpoints, points, p=2).argmin(dim=1)
    return torch.unique(nearest_ids), x_limits, y_limits


# TESTING
# x = confs
# y = variances

# TESTING
# x = torch.asarray([1, 5, 1.5, 3.5, 1.5, 2.5])
# y = torch.asarray([1, 5, 1.5, 3.5, 4.5, 4.5])

# y = logit_variances
# ALL_SAMPLE_IDS = torch.tensor(list(range(len(y_true))))
# sample_ids = ALL_SAMPLE_IDS
# sample_ids = find_image_ids_gridfilling(x, y, IMAGES_SUBSET)




In [None]:

class CrossedSquarePatch:
    """Legend proxy handle: a square outline in `color`, optionally with an 'X' inside.

    It carries no drawing logic itself; `CrossedSquarePatchHandler` renders it when
    registered through the legend's `handler_map`.
    """

    def __init__(self, color = 'red', crossed = False):
        # Edge colour of the square (and of the cross, if any), and whether to draw the cross.
        self.color, self.crossed = color, crossed

class CrossedSquarePatchHandler:
    """Legend handler that draws `CrossedSquarePatch` proxies.

    Implements matplotlib's custom legend-handler protocol (`legend_artist`).
    """

    def legend_artist(self, legend, orig_handle, fontsize, handlebox):
        """Draw the square (and, if requested, its diagonals) into `handlebox`.

        Returns the square patch, as expected from a legend handler.
        """
        x0, y0 = handlebox.xdescent, handlebox.ydescent
        width, height = handlebox.width, handlebox.height
        x1, y1 = x0 + width, y0 + height

        patch = mpatches.Rectangle([x0, y0], width, height, fill=None,
                                   edgecolor=orig_handle.color,
                                   transform=handlebox.get_transform())
        handlebox.add_artist(patch)

        if orig_handle.crossed:
            Path = mpath.Path
            # Both diagonals span the rectangle's corners, offset by the handlebox descent
            # exactly like the rectangle itself.
            verts = [(x0, y0), (x1, y1), (x1, y0), (x0, y1)]
            codes = [Path.MOVETO, Path.LINETO, Path.MOVETO, Path.LINETO]
            path_patch = mpatches.PathPatch(Path(verts, codes), edgecolor=orig_handle.color,
                                            fill=None, alpha=1)
            handlebox.add_artist(path_patch)

        return patch


In [None]:
def _gridfilling_style(dataset):
    """Return the drawing style used by `plot_conf_var_gridfilling` for `dataset`.

    The dict holds the true-class -> border colour map, the thumbnail zoom, the
    border width and the width of the cross drawn over misclassified images.
    """
    if dataset == 'camelyon17':
        return {'class_to_color': {0: 'green', 1: '#CB0000'},
                'zoom': 0.22, 'border_width': 2, 'line_width': 5}
    if dataset == 'SkinLesions':
        return {'class_to_color': {0: 'green', 1: 'red', 2: 'blue', 3: 'orange',
                                   4: 'pink', 5: 'yellow', 6: 'violet', 7: 'lightblue'},
                'zoom': 0.045, 'border_width': 2, 'line_width': 20}
    raise ValueError(f"Unsupported dataset {dataset!r}; expected 'camelyon17' or 'SkinLesions'.")


def _add_gridfilling_legend(ax, dataset, class_to_color):
    """Add a legend explaining border colours (true class) and crosses (misclassified) to `ax`."""
    legend_kwargs = dict(handler_map={CrossedSquarePatch: CrossedSquarePatchHandler()},
                         handlelength=0.7, handleheight=0.7, fontsize='large')
    if dataset == 'camelyon17':
        normal_color, tumor_color = class_to_color[0], class_to_color[1]
        ax.legend([CrossedSquarePatch(normal_color, False),
                   CrossedSquarePatch(tumor_color, False),
                   CrossedSquarePatch(normal_color, True),
                   CrossedSquarePatch(tumor_color, True)],
                  ['Normal; Correctly Classified',
                   'Tumor; Correctly Classified',
                   'Normal; Misclassified as Tumor',
                   'Tumor; Misclassified as Normal'],
                  **legend_kwargs)
    elif dataset == 'SkinLesions':
        ax.legend([CrossedSquarePatch('black', False),
                   CrossedSquarePatch('black', True)],
                  ['Correctly Classified',
                   'Misclassified'],
                  **legend_kwargs)


def plot_conf_var_gridfilling(confs, variances, test_loader, preds, y_true, title_string = '', images_subset = 200, dataset = 'camelyon17'):
    """Scatter confidence vs. variance and overlay thumbnails of an evenly spread subset of inputs.

    Roughly `images_subset` inputs are chosen with `find_image_ids_gridfilling`. Each
    one is drawn as its un-normalised image, with a border coloured by its true class.
    Misclassified inputs additionally get a cross in that colour.

    Parameters
    ----------
    confs, variances : torch.Tensor
        Per-input confidence and variance; converted with `.numpy()` (CPU tensors).
    test_loader
        Loader whose `dataset[i]` returns the image of input `i` first.
    preds, y_true : torch.Tensor
        Predicted and true class per input, indexed like `confs`.
    title_string : str
        Axes title.
    images_subset : int
        Approximate number of thumbnails.
    dataset : str
        'camelyon17' or 'SkinLesions'; selects colours, thumbnail size and legend.

    Raises
    ------
    ValueError
        If `dataset` is not supported.
    """
    style = _gridfilling_style(dataset)
    class_to_color = style['class_to_color']

    sample_ids, x_limits, y_limits = find_image_ids_gridfilling(confs, variances, images_subset)
    print("x_limits: ", x_limits)
    print("y_limits: ", y_limits)

    x = confs.numpy()[sample_ids]
    y = variances.numpy()[sample_ids]
    imgs = get_images(sample_ids, test_loader)

    _, ax = plt.subplots()
    ax.scatter(x, y)
    for x0, y0, img, sample_id in zip(x, y, imgs, sample_ids):
        img = invImageNetNorm(img).permute(1, 2, 0)  # channels-last for display
        color = class_to_color[int(y_true[sample_id])]
        if preds[sample_id] != y_true[sample_id]:
            img = add_cross_to_image(img, color=color, width=style['line_width'])

        thumbnail = AnnotationBbox(OffsetImage(img, zoom=style['zoom']), (x0, y0),
                                   frameon=True, pad=0,
                                   bboxprops=dict(edgecolor=color, linewidth=style['border_width']))
        ax.add_artist(thumbnail)

    ax.set_xlabel('confidence')
    ax.set_ylabel(r'$\leftarrow$ aleatoric $\leftarrow$ ~~~ \textbf{variance} ~~~ $\rightarrow$ epistemic $\rightarrow$')
    _add_gridfilling_legend(ax, dataset, class_to_color)
    ax.set_title(title_string)


In [None]:
# plot_conf_var_gridfilling(confs, variances, title_string = '', images_subset = 200)

In [None]:
distribution_directories = ['./results/predictive_distributions/camelyon17_model6/',
                            './results/predictive_distributions/camelyon17_resnet50/',
                            './results/predictive_distributions/camelyon17_wrn50/']

model_strings = ['densenet121', 'resnet50', 'wrn50']

# Display names used in the figure titles.
ModelStringConversion = {'densenet121': 'DenseNet121', 'resnet50': 'ResNet50', 'wrn50': 'WideResNet50'}
DatasetStringConversion = {'camelyon17-id': 'Camelyon17 (ID)', 'camelyon17-ood': 'Camelyon17 (OOD)'}

# Seed for the Monte-Carlo sampling, so the plotted variances are reproducible per figure.
SAMPLING_SEED = 0

savedir = './results/images/img/ConfVariancePlots'
os.makedirs(savedir, exist_ok=True)

for distribution_directory, model_string in zip(distribution_directories, model_strings):
    for dataset in ['camelyon17-id', 'camelyon17-ood']:
        y_true = torch.load(os.path.join(distribution_directory, "y_true_" + dataset + ".pt"))
        f_mu = torch.load(os.path.join(distribution_directory, "f_mu_" + dataset + ".pt"))
        f_var = torch.load(os.path.join(distribution_directory, "f_var_" + dataset + ".pt"))

        test_loader = get_appropriate_testloader(dataset)

        # Optional sanity checks (they iterate over the whole test set, so they are slow):
        # check_loader_dataset_order(test_loader, y_true)
        # check_testloader_order(test_loader, y_true)

        torch.manual_seed(SAMPLING_SEED)
        confs, preds, variances = calculate_confs_preds_variances(f_mu, f_var, y_true)

        title_string = f'Uncertainty of {ModelStringConversion[model_string]} on {DatasetStringConversion[dataset]}'

        plot_conf_variance(confs, variances, title_string=title_string)

        plot_conf_var_gridfilling(confs, variances, test_loader, preds, y_true, title_string = title_string, images_subset = 100)

        plt.savefig(os.path.join(savedir, f'{model_string}_{dataset}.pdf'))
        plt.show()




In [None]:
# TODO modify plotting code to not have a plt.show()
# if not os.path.exists('img/Results/ToyData_ECE_NLL/'):
#     os.makedirs('img/Results/ToyData_ECE_NLL/')
# plt.savefig('img/Results/ToyData_ECE_NLL/MAP_LLLA_TS_ECE_NLL.pdf')
# plt.show()


# SkinLesions

In [None]:
distribution_directories = ['./results/predictive_distributions/SkinLesions/',
                            './results/predictive_distributions/SkinLesions_wrn50/']

model_strings = ['resnet50', 'wrn50']

# Display names used in the figure titles.
ModelStringConversion = {'densenet121': 'DenseNet121', 'resnet50': 'ResNet50', 'wrn50': 'WideResNet50'}
DatasetStringConversion = {'camelyon17-id': 'Camelyon17 (ID)', 'camelyon17-ood': 'Camelyon17 (OOD)', 'SkinLesions-id': 'SkinLesions (ID)', 'SkinLesions-ood': 'SkinLesions (OOD)'}

# Seed for the Monte-Carlo sampling, so the plotted variances are reproducible per figure.
SAMPLING_SEED = 0

savedir = './results/images/img/ConfVariancePlots'
os.makedirs(savedir, exist_ok=True)

for distribution_directory, model_string in zip(distribution_directories, model_strings):
    for dataset in ['SkinLesions-id', 'SkinLesions-ood']:
        y_true = torch.load(os.path.join(distribution_directory, "y_true_" + dataset + ".pt"))
        f_mu = torch.load(os.path.join(distribution_directory, "f_mu_" + dataset + ".pt"))
        f_var = torch.load(os.path.join(distribution_directory, "f_var_" + dataset + ".pt"))

        test_loader = get_appropriate_testloader(dataset)

        # Optional sanity checks (they iterate over the whole test set, so they are slow):
        # check_loader_dataset_order(test_loader, y_true)
        # check_testloader_order(test_loader, y_true)

        torch.manual_seed(SAMPLING_SEED)
        confs, preds, variances = calculate_confs_preds_variances(f_mu, f_var, y_true)

        title_string = f'Uncertainty of {ModelStringConversion[model_string]} on {DatasetStringConversion[dataset]}'

        plot_conf_variance(confs, variances, title_string=title_string)

        plot_conf_var_gridfilling(confs, variances, test_loader, preds, y_true, title_string = title_string, images_subset = 100, dataset='SkinLesions')

        plt.savefig(os.path.join(savedir, f'{model_string}_{dataset}.pdf'))
        plt.show()
