In [1]:
import os
os.getcwd()

'/mnt/qb/work/hennig/hmx148/MastersThesisCode/laplace-redux'

In [2]:
import argparse
import yaml

import torch
import pycalib
from laplace import Laplace

import utils.data_utils as du
import utils.wilds_utils as wu
import utils.utils as util
from utils.test import test
from marglik_training.train_marglik import get_backend

# import warnings
# warnings.filterwarnings('ignore')

from argparse import Namespace

from tqdm import tqdm

import matplotlib.pyplot as plt

from copy import deepcopy

from random import randint

import numpy as np

In [3]:
from tueplots import bundles



# Inspired by bundles.neurips2023(), but adapting font sizes for pt12 standard

settings_dict = {'text.usetex': True,
                 'font.family': 'serif',
                 'text.latex.preamble': '\\renewcommand{\\rmdefault}{ptm}\\renewcommand{\\sfdefault}{phv}',
                 'figure.figsize': (5.5, 3.399186938124422),
                 'figure.constrained_layout.use': True,
                 'figure.autolayout': False,
                 'savefig.bbox': 'tight',
                 'savefig.pad_inches': 0.015,
                 'font.size': 10,
                 'axes.labelsize': 10,
                 'legend.fontsize': 8,
                 'xtick.labelsize': 8,
                 'ytick.labelsize': 8,
                 'axes.titlesize': 10,
                 'figure.dpi': 300}


plt.rcParams.update(settings_dict)


# Can use colors from bundles.rgb.
#     tue_blue
#     tue_brown
#     tue_dark
#     tue_darkblue
#     tue_darkgreen
#     tue_gold
#     tue_gray
#     tue_green
#     tue_lightblue
#     tue_lightgold
#     tue_lightgreen
#     tue_lightorange
#     tue_mauve
#     tue_ocre
#     tue_orange
#     tue_red
#     tue_violet

In [4]:
from torchvision import transforms

def invImageNetNorm(x):
    """ Inverts the Normalization given by:
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]) """
    invTrans = transforms.Compose([ transforms.Normalize(mean = [ 0., 0., 0. ],
                                                     std = [ 1/0.229, 1/0.224, 1/0.225 ]),
                                transforms.Normalize(mean = [ -0.485, -0.456, -0.406 ],
                                                     std = [ 1., 1., 1. ]),
                               ])

    return invTrans(x)

In [5]:
# make Plots with images from the test set
# TOP: ID
# Bottom: OOD

# Left to right: different classes


In [6]:
# Make overview table of the datasets:
# Top->bottom: train, val test set
# Left->right classes, total
# entries: number of objects (percentage)

In [7]:
def print_dataset_statistics(class_names, train_loader=None, IDval_loader=None, IDtest_loader=None, OODval_loader=None, OODtest_loader=None):
    print('\\begin{table}\n\\begin{center}\n\\begin{tabular}{' + "l|" + "c" * (len(class_names) + 1) + '}')
    print("    " + " Dataset & " + " & ".join(class_names) + " & total" + " \\\\\n    \\hline")
    for loader, name in zip([train_loader, IDval_loader, IDtest_loader, OODval_loader, OODtest_loader],
                            ['train (ID)', 'val (ID)', 'test (ID)', 'val (OOD)', 'test (OOD)']):
        if loader:
            labels = []
            for x, y in loader:
                labels.append(y)
            classes, counts = torch.concat(labels).unique(return_counts=True)
            count_proportions = counts / counts.sum()
            count_percentages = count_proportions * 100
            line_string = "    " + name
            for count, percentage in zip(counts, count_percentages):
                line_string += " & " + f'{count} ({percentage:.1f})'
            line_string += f" & {counts.sum()}"
            line_string += " \\\\"
            print(line_string)
        else:
            line_string = "    " + name + (" & " + " -- ") * (len(class_names) + 1) + " \\\\"
            print(line_string)
    print("\\end{tabular}\n\\end{center}\n\\caption{[TODO]}\\label{[TODO]}\n\\end{table}")

In [8]:
dataset = 'camelyon17'
train_loader, IDval_loader, IDtest_loader = wu.get_wilds_loaders(
            dataset, './data', 1.0, 1, download=False, use_ood_val_set=False)
OODtest_loader = wu.get_wilds_ood_test_loader(
            dataset, './data', 1.0)
_, OODval_loader, _ = wu.get_wilds_loaders(
            dataset, './data', 1.0, 1, download=False, use_ood_val_set=True)

camelyon17 dataset doesn't have an in-distribution test split -- using validation split instead!
Using the OOD validation set instead of the ID validation set
camelyon17 dataset doesn't have an in-distribution test split -- using validation split instead!


In [9]:
class_names = ["normal", "tumor"]
print_dataset_statistics(class_names=class_names, train_loader=train_loader, IDval_loader=IDval_loader, IDtest_loader=None, OODval_loader=OODval_loader, OODtest_loader=OODtest_loader)

\begin{table}
\begin{center}
\begin{tabular}{l|ccc}
     Dataset & normal & tumor & total \\
    \hline
    train (ID) & 151046 (49.9) & 151390 (50.1) & 302436 \\
    val (ID) & 16952 (50.5) & 16608 (49.5) & 33560 \\
    test (ID) &  --  &  --  &  --  \\
    val (OOD) & 17452 (50.0) & 17452 (50.0) & 34904 \\
    test (OOD) & 42527 (50.0) & 42527 (50.0) & 85054 \\
\end{tabular}
\end{center}
\caption{[TODO]}\label{[TODO]}
\end{table}


In [10]:
# Skinlesions
train_loader, IDval_loader, IDtest_loader = du.get_ham10000_loaders('./data', batch_size=16, train_batch_size=16, num_workers=4, image_size=512)
OODtest_loader = du.get_SkinLesions_ood_loader(None, data_path='./data', batch_size=16, num_workers=4, image_size=512)

SKINLESIONS_CLASS_TO_IDX = {'akiec': 0, 'bcc': 1, 'bkl': 2, 'df': 3, 'mel': 4, 'nv': 5, 'vasc': 6}
class_names = SKINLESIONS_CLASS_TO_IDX.keys()

print_dataset_statistics(class_names=class_names,
                         train_loader=train_loader,
                         IDval_loader=IDval_loader,
                         IDtest_loader=IDtest_loader,
                         OODtest_loader=OODtest_loader)


\begin{table}
\begin{center}
\begin{tabular}{l|cccccccc}
     Dataset & akiec & bcc & bkl & df & mel & nv & vasc & total \\
    \hline
    train (ID) & 327 (3.3) & 514 (5.1) & 1099 (11.0) & 115 (1.1) & 1113 (11.1) & 6705 (66.9) & 142 (1.4) & 10015 \\
    val (ID) & 8 (4.1) & 15 (7.8) & 22 (11.4) & 1 (0.5) & 21 (10.9) & 123 (63.7) & 3 (1.6) & 193 \\
    test (ID) & 43 (2.8) & 93 (6.2) & 217 (14.4) & 44 (2.9) & 171 (11.3) & 909 (60.1) & 35 (2.3) & 1512 \\
    val (OOD) &  --  &  --  &  --  &  --  &  --  &  --  &  --  &  --  \\
    test (OOD) & 1175 (3.6) & 2926 (9.1) & 1705 (5.3) & 171 (0.5) & 4460 (13.8) & 21708 (67.2) & 169 (0.5) & 32314 \\
\end{tabular}
\end{center}
\caption{[TODO]}\label{[TODO]}
\end{table}


In [11]:
# Amazon
dataset = 'amazon'
train_loader, IDval_loader, IDtest_loader = wu.get_wilds_loaders(
            dataset, './data', 1.0, 1, download=False, use_ood_val_set=False)
OODtest_loader = wu.get_wilds_ood_test_loader(
            dataset, './data', 1.0)
_, OODval_loader, _ = wu.get_wilds_loaders(
            dataset, './data', 1.0, 1, download=False, use_ood_val_set=True)

class_names = ["$\\bigstar \\openbigstar \\openbigstar \\openbigstar \\openbigstar$",
               "$\\bigstar \\bigstar \\openbigstar \\openbigstar \\openbigstar$",
               "$\\bigstar \\bigstar \\bigstar \\openbigstar \\openbigstar$",
               "$\\bigstar \\bigstar \\bigstar \\bigstar \\openbigstar$",
               "$\\bigstar \\bigstar \\bigstar \\bigstar \\bigstar$"]

print_dataset_statistics(class_names=class_names,
                         train_loader=train_loader,
                         IDval_loader=IDval_loader,
                         IDtest_loader=IDtest_loader,
                         OODval_loader=OODval_loader,
                         OODtest_loader=OODtest_loader)


Using the OOD validation set instead of the ID validation set
\begin{table}
\begin{center}
\begin{tabular}{l|cccccc}
     Dataset & $\bigstar \openbigstar \openbigstar \openbigstar \openbigstar$ & $\bigstar \bigstar \openbigstar \openbigstar \openbigstar$ & $\bigstar \bigstar \bigstar \openbigstar \openbigstar$ & $\bigstar \bigstar \bigstar \bigstar \openbigstar$ & $\bigstar \bigstar \bigstar \bigstar \bigstar$ & total \\
    \hline
    train (ID) & 2648 (1.1) & 6745 (2.7) & 22903 (9.3) & 71949 (29.3) & 141257 (57.5) & 245502 \\
    val (ID) & 586 (1.2) & 1323 (2.8) & 4060 (8.6) & 13010 (27.7) & 27971 (59.6) & 46950 \\
    test (ID) & 572 (1.2) & 1304 (2.8) & 4496 (9.6) & 13287 (28.3) & 27291 (58.1) & 46950 \\
    val (OOD) & 1413 (1.4) & 2886 (2.9) & 9315 (9.3) & 27908 (27.9) & 58528 (58.5) & 100050 \\
    test (OOD) & 1643 (1.6) & 3212 (3.2) & 9972 (10.0) & 28258 (28.2) & 56965 (56.9) & 100050 \\
\end{tabular}
\end{center}
\caption{[TODO]}\label{[TODO]}
\end{table}


In [12]:
from PIL import Image 

In [13]:
os.getcwd()

'/mnt/qb/work/hennig/hmx148/MastersThesisCode/laplace-redux'

In [14]:
def save_num_images_from_dataset(class_names, IDtest_loader, OODtest_loader, dataset_name, num_images=20):
    savedir = f"./results/images/{dataset_name}"
    if not os.path.exists(savedir):
        os.makedirs(savedir)

    for loader, condition_name in zip([IDtest_loader, OODtest_loader], ["ID", "OOD"]):
        for class_id, class_name in enumerate(class_names):
            imgs = []
            img_found_num = 0
            for x, y in loader:
                if img_found_num == num_images:
                    break
                if torch.any(y == class_id):
                    img = x[y == class_id][0]
                    imgs.append(img)
                    img_found_num += 1

                    img = invImageNetNorm(img).permute(1,2,0)
                    # img = invImageNetNorm(img) # .permute(1,2,0)

                    img = img.numpy()
                    img = (img * 255).astype('uint8')
                    # print(img)

                    pil_img = Image.fromarray(img, 'RGB')

                    filename = f'{condition_name}_{class_name}_{img_found_num}.png'
                    pil_img.save(os.path.join(savedir, filename), 'PNG')


In [15]:
# Cameylon17
class_names = ["normal", "tumor"]
dataset = 'camelyon17'
train_loader, IDval_loader, IDtest_loader = wu.get_wilds_loaders(
            dataset, './data', 1.0, 1, download=False, use_ood_val_set=False)
OODtest_loader = wu.get_wilds_ood_test_loader(
            dataset, './data', 1.0)
_, OODval_loader, _ = wu.get_wilds_loaders(
            dataset, './data', 1.0, 1, download=False, use_ood_val_set=True)
save_num_images_from_dataset(class_names=class_names, IDtest_loader=IDtest_loader, OODtest_loader=OODtest_loader, dataset_name="Camelyon17")

camelyon17 dataset doesn't have an in-distribution test split -- using validation split instead!
Using the OOD validation set instead of the ID validation set
camelyon17 dataset doesn't have an in-distribution test split -- using validation split instead!


In [16]:

# Skinlesions
train_loader, IDval_loader, IDtest_loader = du.get_ham10000_loaders('./data', batch_size=16, train_batch_size=16, num_workers=4, image_size=512)
OODtest_loader = du.get_SkinLesions_ood_loader(None, data_path='./data', batch_size=16, num_workers=4, image_size=512)

SKINLESIONS_CLASS_TO_IDX = {'akiec': 0, 'bcc': 1, 'bkl': 2, 'df': 3, 'mel': 4, 'nv': 5, 'vasc': 6}
class_names = SKINLESIONS_CLASS_TO_IDX.keys()


save_num_images_from_dataset(class_names=class_names, IDtest_loader=IDtest_loader, OODtest_loader=OODtest_loader, dataset_name="SkinLesions")