In [None]:
import os
# Show the kernel's working directory; the relative './results/...' paths used when loading data resolve against it.
os.getcwd()

In [None]:
import argparse
import yaml

import torch
import pycalib
from laplace import Laplace

import utils.data_utils as du
import utils.wilds_utils as wu
import utils.utils as util
from utils.test import test
from marglik_training.train_marglik import get_backend

# import warnings
# warnings.filterwarnings('ignore')

from argparse import Namespace

from tqdm import tqdm

import matplotlib.pyplot as plt

from copy import deepcopy

from random import randint

import numpy as np

In [None]:
from torchvision import transforms

def invImageNetNorm(x, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Invert the per-channel normalization applied by

        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])

    i.e. return ``x * std + mean`` channel-wise.

    Args:
        x: floating-point image tensor with channels on the third-to-last axis,
           e.g. (C, H, W) or a batch (N, C, H, W).
        mean, std: per-channel statistics of the normalization to undo. The defaults
           are the ImageNet values, so existing one-argument calls behave as before.

    Returns:
        A new tensor with the same shape as ``x``, in the un-normalized scale.

    Raises:
        TypeError: if ``x`` is not a floating-point tensor (as Normalize does).
    """
    if not x.is_floating_point():
        raise TypeError(f"Input tensor should be a float tensor. Got {x.dtype}.")
    # Shape the statistics as (C, 1, 1) so they broadcast over H, W and any leading batch axes.
    mean_t = torch.as_tensor(mean, dtype=x.dtype, device=x.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=x.dtype, device=x.device).view(-1, 1, 1)
    return x * std_t + mean_t

In [None]:
# Which run's saved predictive distributions to analyse; uncomment one alternative to switch.
# DISTRIBUTIONS_DIRECTORY = './results/predictive_distributions/amazon_vanilla/'
# DISTRIBUTIONS_DIRECTORY = './results/predictive_distributions/amazon_ts_vanilla/'
# DISTRIBUTIONS_DIRECTORY = './results/predictive_distributions/camelyon17_vanilla/'

DISTRIBUTIONS_DIRECTORY = './results/predictive_distributions/camelyon17_ts/'
# DISTRIBUTIONS_DIRECTORY = './results/predictive_distributions/camelyon17_scaling_fitted/'


# DATASET = 'camelyon17-id' # 'camelyon17-ood'
DATASET = 'camelyon17-ood'


def _load_prediction_tensor(prefix):
    """Load ``<DISTRIBUTIONS_DIRECTORY>/<prefix><DATASET>.pt`` onto the CPU.

    ``map_location="cpu"`` lets files written from a GPU run load on a CPU-only machine,
    and keeps later ``.numpy()`` calls (which only accept CPU tensors) working.
    NOTE: torch.load uses a pickle-based format; only point this at trusted result directories.
    """
    path = os.path.join(DISTRIBUTIONS_DIRECTORY, prefix + DATASET + ".pt")
    return torch.load(path, map_location="cpu")


x = _load_prediction_tensor("x_")
y_true = _load_prediction_tensor("y_true_")
y_prob = _load_prediction_tensor("y_prob_")

# f_mu = _load_prediction_tensor("f_mu_")
f_var = _load_prediction_tensor("f_var_")




In [None]:
# Sum of the labels (for 0/1 labels: the number of positives).
torch.sum(y_true)

In [None]:
# Display the raw label tensor (one label per data point).
y_true

In [None]:
# Number of labelled data points, as a torch.Size.
y_true.size()

In [None]:
# DATA_SUBSET = 1000

# x = x[:DATA_SUBSET]
# y_true = y_true[:DATA_SUBSET]
# y_prob = y_prob[:, :DATA_SUBSET]


In [None]:
def batch_cov(points):
    """Unbiased sample covariance matrix for each batch element.

    Args:
        points: tensor of shape (B, N, D): B independent sets of N observations
            of a D-dimensional variable.

    Returns:
        Tensor of shape (B, D, D) holding sum_n (p_n - mean)(p_n - mean)^T / (N - 1)
        for every batch element (Bessel-corrected; N == 1 yields NaN, as 0 / 0).
    """
    # Unpacking also validates that `points` is 3-D.
    _, n_obs, _ = points.size()
    centered = points - points.mean(dim=1, keepdim=True)  # (B, N, D)
    # Batched Gram product (B, D, N) @ (B, N, D) -> (B, D, D): the same sum of outer products,
    # but without materializing a (B, N, D, D) intermediate, which is what exhausted memory.
    return centered.transpose(1, 2) @ centered / (n_obs - 1)



In [None]:
# Shape of the sampled probabilities; the indexing below treats it as (draws, data points, classes).
y_prob.size()

In [None]:
# Per-data-point covariance of the sampled class probabilities, plus the mean prediction.
# Computing everything at once (batch_cov(y_prob.permute(1, 0, 2))) can exhaust memory,
# so the data axis (dim 1) is processed in fixed-size chunks and the results concatenated.
COV_CHUNK_SIZE = 5000
covariances = torch.cat([
    batch_cov(chunk.permute(1, 0, 2))
    for chunk in torch.split(y_prob, COV_CHUNK_SIZE, dim=1)
])
y_pred = y_prob.mean(dim=0)


In [None]:
# Confidence = highest mean class probability, prediction = its class index.
confs, preds = y_pred.max(dim=1)
print("conf: ", confs.mean().item())
print("acc: ", preds.eq(y_true).float().mean().item())


In [None]:
# Variance of each data point's predicted-class probability: the diagonal entry
# covariances[i, preds[i], preds[i]], gathered with one vectorized index instead of a Python loop.
_rows = torch.arange(len(preds), device=covariances.device)
_cls = preds.to(covariances.device)
variances = covariances[_rows, _cls, _cls]

print("mean_variance: ", variances.mean().item())

In [None]:
# Logit-space variance of each data point's predicted class: f_var[i, preds[i], preds[i]], vectorized.
# NOTE(review): assumes f_var is a tensor of per-point (n_classes x n_classes) matrices, as the
# per-point indexing c[preds[i], preds[i]] implies; confirm against how it was saved.
_rows = torch.arange(len(preds), device=f_var.device)
_cls = preds.to(f_var.device)
logit_variances = f_var[_rows, _cls, _cls]


In [None]:
# Boolean masks over the data: did the mean prediction hit the true label?
correctly_classified = preds.eq(y_true)
wrongly_classified = ~correctly_classified

In [None]:
# Index sets used to build the analysis conditions.
ALL_SAMPLE_IDS = torch.arange(len(y_true))

# Split by whether the mean prediction matches the label.
IDS_CORRECT = ALL_SAMPLE_IDS[preds.eq(y_true)]
IDS_WRONG = ALL_SAMPLE_IDS[preds.ne(y_true)]

# Full orderings by confidence (most / least confident first), and the points whose
# confidence lies in the closed band [0.6, 0.7], kept in sample-ID order.
IDS_HIGH_CONFIDENCE = confs.argsort(descending=True)
IDS_LOW_CONFIDENCE = confs.argsort(descending=False)
IDS_MIDDLE_CONFIDENCE = ALL_SAMPLE_IDS[(confs >= 0.6) & (confs <= 0.7)]


# Alternative: rank by the variance of the predicted-class probability instead of the logit.
# IDS_HIGH_VARIANCE = torch.argsort(variances, descending=True)
# IDS_LOW_VARIANCE = torch.argsort(variances, descending=False)

# Full orderings by logit-space variance of the predicted class (highest / lowest first).
IDS_HIGH_VARIANCE = logit_variances.argsort(descending=True)
IDS_LOW_VARIANCE = logit_variances.argsort(descending=False)

In [None]:
def three_intersection(arg1, arg2, arg3):
    """Sorted unique values present in all three array-likes (via np.intersect1d)."""
    common_first_two = np.intersect1d(arg1, arg2)
    return np.intersect1d(common_first_two, arg3)

In [None]:
AMOUNT_IN_EACH_CONDITION = 30


def _select_condition_ids(correctness_ids, confidence_order, variance_order, amount):
    """Pick up to `amount` sample IDs in `correctness_ids` that rank near the top of both orderings.

    The number k of top-ranked IDs taken from `confidence_order` and `variance_order` grows in
    steps of `amount` until their intersection with `correctness_ids` holds at least `amount` IDs.
    The last step always uses the complete orderings, so every call returns a result. The result
    is shorter than `amount` (possibly empty) when too few samples satisfy all three criteria.

    NOTE: np.intersect1d returns sorted unique values, so the IDs kept are the smallest sample
    IDs inside the intersection, not necessarily the most extreme ones.
    """
    longest = max(len(confidence_order), len(variance_order))
    cutoffs = list(range(amount, longest, amount)) + [longest]
    for top_k in cutoffs:
        ids = three_intersection(correctness_ids, confidence_order[:top_k], variance_order[:top_k])
        if len(ids) >= amount:
            break
    return ids[:amount]


# NOTE(review): IDS_MIDDLE_CONFIDENCE is a masked ALL_SAMPLE_IDS (sample-ID order, not ranked by
# confidence), so its "top-k" prefixes are simply its first k members.
condition_dict = {}
for correctness, correctness_name in zip([IDS_CORRECT, IDS_WRONG], ["correct", "wrong"]):
    for confidence, confidence_name in zip([IDS_HIGH_CONFIDENCE, IDS_LOW_CONFIDENCE, IDS_MIDDLE_CONFIDENCE], ["high_conf", "low_conf", "middle_conf"]):
        for variance, variance_name in zip([IDS_HIGH_VARIANCE, IDS_LOW_VARIANCE], ["high_variance", "low_variance"]):
            condition_string = f'{correctness_name}_{confidence_name}_{variance_name}'
            condition_ids = _select_condition_ids(correctness, confidence, variance, AMOUNT_IN_EACH_CONDITION)
            if len(condition_ids) < AMOUNT_IN_EACH_CONDITION:
                print(f"warning: only {len(condition_ids)} samples found for {condition_string}")
            condition_dict[condition_string] = condition_ids

In [None]:
# # Examine:
# # Correctly classified, with high confidence
# sort = torch.argsort(confs, descending=True)
# ID_CORRECT_HIGH_CONF = sort[torch.nonzero(correctly_classified[sort])[0]].item()

# # correctly classified with low confidence
# sort = torch.argsort(confs, descending=False)
# ID_CORRECT_LOW_CONF = sort[torch.nonzero(correctly_classified[sort])[0]].item()


# # wrongly classified with high confidence
# sort = torch.argsort(confs, descending=True)
# ID_WRONG_HIGH_CONF = sort[torch.nonzero(wrongly_classified[sort])[0]].item()

# # wrongly classified with low confidence
# sort = torch.argsort(confs, descending=False)
# ID_WRONG_LOW_CONF = sort[torch.nonzero(wrongly_classified[sort])[0]].item()


# # correctly classified, with high variance in the predicted class
# sort = torch.argsort(variances, descending=True)
# ID_CORRECT_HIGH_VARIANCE = sort[torch.nonzero(correctly_classified[sort])[0]].item()


# # wrongly classified, with high variance in the predicted class
# sort = torch.argsort(variances, descending=True)
# ID_WRONG_HIGH_VARIANCE = sort[torch.nonzero(wrongly_classified[sort])[0]].item()


In [None]:
# sample_ids = [ID_CORRECT_HIGH_CONF, ID_CORRECT_LOW_CONF, ID_WRONG_HIGH_CONF, ID_WRONG_LOW_CONF, ID_CORRECT_HIGH_VARIANCE, ID_WRONG_HIGH_VARIANCE]
# sample_names = ["ID_CORRECT_HIGH_CONF", "ID_CORRECT_LOW_CONF", "ID_WRONG_HIGH_CONF", "ID_WRONG_LOW_CONF", "ID_CORRECT_HIGH_VARIANCE", "ID_WRONG_HIGH_VARIANCE"]

In [None]:
# for SAMPLE_ID, SAMPLE_NAME in zip(sample_ids, sample_names):
#     print("sample_name: ", SAMPLE_NAME)
#     print("SAMPLE_ID: ", SAMPLE_ID)
#     print("conf: ", confs[SAMPLE_ID])
#     print("Correct: ", y_true[SAMPLE_ID] == preds[SAMPLE_ID])
#     print("variance: ", variances[SAMPLE_ID])


In [None]:

# TODO: do all combinations of [Correct, Wrong] x [high confidence, low confidence] x [high variance, low variance]

# Observe: 
#   high confidence -> no uncertainty
#   low confidence, high variance ~= epistemic uncertainty (uncertainty is due to the randomness in the model weights)
#   low confidence, low variance ~= aleatoric uncertainty (the model is consistently of low confidence,
#                                           regardless of small fluctuations in the weights; the uncertainty is due to true randomness in the training data)


In [None]:
# Summarize the first sample of each condition. Conditions without any matching sample are
# reported instead of indexing an empty array (which raised IndexError).
for sample_name, sample_ids in condition_dict.items():
    print("sample_name: ", sample_name)
    if len(sample_ids) == 0:
        print("  no samples matched this condition")
        continue
    sample_id = int(sample_ids[0])
    print("SAMPLE_ID: ", sample_id)
    print("conf: ", confs[sample_id].item())
    print("Correct: ", (y_true[sample_id] == preds[sample_id]).item())
    print("variance: ", variances[sample_id].item())
    # The conditions are ranked by logit-space variance, so show it alongside.
    print("logit_variance: ", logit_variances[sample_id].item())


In [None]:
# Just random samples: 
sample_ids = np.random.choice(ALL_SAMPLE_IDS, AMOUNT_IN_EACH_CONDITION,  replace=False)
sample_name = "Random Images"

fig, axs = plt.subplots(AMOUNT_IN_EACH_CONDITION, y_prob.shape[-1] + 1) # number of possible classes
fig.set_size_inches(14, AMOUNT_IN_EACH_CONDITION * 1.2)

for i, sample_id in enumerate(sample_ids):
        
    axs[i][0].imshow(invImageNetNorm(x[sample_id]).permute(1,2,0))
    axs[i][0].set_ylabel(r'y=' + f'{y_true[sample_id].item()}; ' + r'$\hat{y}=$' + f'{preds[sample_id]}')
    axs[i][0].set_xticks([])
    axs[i][0].set_yticks([])

    for c in range(y_prob.shape[-1]):
        axs[i][c+1].hist(y_prob[:,sample_id, c].numpy(), bins=20, range=(0,1))
        axs[i][c+1].set_yticks([])



fig.suptitle(f"histogram of the confidences in each individual class\n{sample_name}", fontsize=20)
plt.show()


In [None]:
# One figure per condition: each row shows the image and one histogram per class.
for sample_name, sample_ids in condition_dict.items():
    if len(sample_ids) == 0:
        print(f"{sample_name}: no samples matched this condition, skipping plot")
        continue
    n_rows = len(sample_ids)  # conditions can hold fewer than AMOUNT_IN_EACH_CONDITION samples
    n_classes = y_prob.shape[-1]
    # squeeze=False keeps `axs` 2-D even when a condition has a single sample.
    fig, axs = plt.subplots(n_rows, n_classes + 1, squeeze=False)
    fig.set_size_inches(14, n_rows * 1.2)

    for row, sample_id in enumerate(sample_ids):
        image_ax = axs[row][0]
        # Undo the ImageNet normalization; clamp because float error can push values just outside [0, 1].
        image_ax.imshow(invImageNetNorm(x[sample_id]).clamp(0, 1).permute(1, 2, 0).cpu())
        image_ax.set_ylabel(r'y=' + f'{y_true[sample_id].item()}; ' + r'$\hat{y}=$' + f'{preds[sample_id].item()}')
        image_ax.set_xticks([])
        image_ax.set_yticks([])

        for c in range(n_classes):
            axs[row][c + 1].hist(y_prob[:, sample_id, c].cpu().numpy(), bins=20, range=(0, 1))
            axs[row][c + 1].set_yticks([])

    fig.suptitle(f"histogram of the confidences in each individual class\n{sample_name}", fontsize=20)
    plt.show()


In [None]:
# for sample_name, sample_ids  in condition_dict.items():
#     fig, axs = plt.subplots(AMOUNT_IN_EACH_CONDITION, 2) # number of possible classes
#     fig.set_size_inches(14, AMOUNT_IN_EACH_CONDITION * 1.2)

#     for i, sample_id in enumerate(sample_ids):


#         axs[i][0].imshow(invImageNetNorm(x[sample_id]).permute(1,2,0))
#         axs[i][0].set_ylabel(y_true[sample_id].item())
#         axs[i][0].set_xticks([])
#         axs[i][0].set_yticks([])


#         probs = y_prob[:, sample_id, :]
#         m = probs.mean(dim=0)
#         v = probs.std(dim=0)

#         axs[i][1].bar(range(y_prob.shape[-1]), m, yerr=v)

#         axs[i][1].set_ylim([0,1])



#     fig.suptitle(f"Posterior predictive distributions with per class variances of the confidence\n{sample_name}", fontsize=20)
#     plt.show()


In [None]:

# for sample_name, sample_ids  in condition_dict.items():
#     fig, axs = plt.subplots(AMOUNT_IN_EACH_CONDITION, 3) # number of possible classes
#     fig.set_size_inches(14, AMOUNT_IN_EACH_CONDITION * 1.2)

#     for i, sample_id in enumerate(sample_ids):
#         axs[i][0].imshow(invImageNetNorm(x[sample_id]).permute(1,2,0))
#         axs[i][0].set_ylabel(y_true[sample_id].item())
#         axs[i][0].set_xticks([])
#         axs[i][0].set_yticks([])


#         probs = y_prob[:, sample_id, :]
#         m = probs.mean(dim=0)
#         v = probs.std(dim=0)

#         axs[i][1].bar(range(y_prob.shape[-1]), m, yerr=v)

#         axs[i][1].set_ylim([0,1])

#         mat = axs[i][2].matshow(covariances[sample_id])

#     # plt.matshow(df.corr(), fignum=fig.number)
#     # plt.xticks(range(df.select_dtypes(['number']).shape[1]), df.select_dtypes(['number']).columns, fontsize=14, rotation=45)
#     # plt.yticks(range(df.select_dtypes(['number']).shape[1]), df.select_dtypes(['number']).columns, fontsize=14)
#     cb = plt.colorbar(mat)
#     # cb.ax.tick_params(labelsize=14)



#     fig.suptitle(f"Posterior predictive distributions with per class variances of the confidence and covariance matrix\n{sample_name}", fontsize=10)
#     plt.show()


In [None]:
# TODO plot histogram along each class