In [None]:
import pickle
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
plt.rcParams["font.family"] = "Times New Roman"
import seaborn as sns

%matplotlib inline
import torch

In [None]:
import io

class CPU_Unpickler(pickle.Unpickler):
    """``pickle.Unpickler`` that restores embedded torch storages on the CPU.

    When the pickle stream asks for ``torch.storage._load_from_bytes`` it is
    handed a replacement that calls ``torch.load`` with
    ``map_location='cpu'``. Every other global is resolved by the stock
    ``pickle.Unpickler`` lookup.
    """

    def find_class(self, module, name):
        """Return the object the unpickler should use for ``module.name``."""
        is_storage_loader = (module, name) == ('torch.storage', '_load_from_bytes')
        if not is_storage_loader:
            return super().find_class(module, name)

        def _load_on_cpu(b):
            # `b` is the raw byte payload the storage was serialised to.
            return torch.load(io.BytesIO(b), map_location='cpu')

        return _load_on_cpu

In [None]:
# Directory the SWAG result pickles are read from.
filepath = "sublayer_swag_results"

# model_codes[i] is the identifier embedded in the result filenames and
# model_names[i] is the display name used for that same configuration.
# NOTE(review): each code looks like a per-layer on/off bitmask — confirm.
model_codes = [
    "00000000",
    "00000001",
    "00000111",
    "00011111",
    "01111111",
    "00011000",
    "00111100",
    "01111110",
    "10000000",
    "11100000",
    "11111000",
    "11111110",
    "10000001",
    "11111111",
]
model_names = [
    "MAP",
    "Output",
    "Output+",
    "Output++",
    "Output+++",
    "Bottleneck",
    "Bottleneck+",
    "Bottleneck++",
    "Input",
    "Input+",
    "Input++",
    "Input+++",
    "Input+Output",
    "All",
]

# Load SWAG results dataframe

In [None]:
# Collect one record per (model, seed, dataset, intensity) from the SWAG
# result pickles, then average the metrics over datasets.
all_results_df_list_swag = []

for model_code, model_name in zip(model_codes, model_names):
    for seed in range(10):
        results_path = f"{filepath}/{model_code}_s{seed}_val.pkl"
        try:
            # NOTE(security): unpickling can execute arbitrary code; only
            # load result files from a trusted source.
            # The context manager closes the file handle after loading.
            with open(results_path, "rb") as results_file:
                results = CPU_Unpickler(results_file).load()
        except FileNotFoundError:
            print(f"Skipping one of {model_name} seed {seed}")
            # Skip this (model, seed) entirely. Falling through would reuse
            # `results` from the previously loaded file (or raise NameError
            # if nothing has been loaded yet) and record it under the wrong
            # model/seed.
            continue

        for name, corr_dict in results["test_results"].items():
            n_levels = len(corr_dict["acc"])
            for i in range(n_levels):
                all_results_df_list_swag.append({
                    "dataset": name,
                    # A single entry is recorded as "No corruption";
                    # otherwise entries are numbered 1..n_levels.
                    "intensity": "No corruption" if n_levels == 1 else 1 + i,
                    "acc": 100*corr_dict["acc"][i],
                    "ece": 100*corr_dict["ece"][i],
                    "nll": corr_dict["nll"][i],
                    "model_code": model_code,
                    "model_name": model_name,
                    "seed": seed,
                    # The MAP rows get a placeholder learning rate of 1.
                    "best_lr": results["best_lr"] if model_name != "MAP" else 1,
                })

all_results_df_swag = pd.DataFrame(all_results_df_list_swag)
# numeric_only=True: the frame also holds string columns ("dataset",
# "model_code"); recent pandas raises when asked to average them instead of
# silently dropping them.
all_results_df_swag = (
    all_results_df_swag
    .groupby(["model_name", "seed", "intensity"])
    .mean(numeric_only=True)
    .reset_index()
)

In [None]:
# Display names of all configurations, grouped by name family.
all_model_names = [
    "MAP",
    "Output", "Output+", "Output++", "Output+++",
    "Bottleneck", "Bottleneck+", "Bottleneck++",
    "Input", "Input+", "Input++", "Input+++",
    "Input+Output",
    "All",
]

In [None]:
# Figure-wide styling and the canvas for the four-panel SWAG figure.
sns.set_theme(style="white", font="Times New Roman", font_scale=1, context="paper")
plt.rcParams['ytick.left'] = True
plt.figure(figsize=(3.5, 1.75), dpi=500)

# Models to plot, left to right; `labels` and `colors` are index-aligned
# with this list.
hue_order = ["MAP", "Output", "Input+", "All"]

labels = [
    "Deterministic Network (MAP)",
    "Output Layer Stochastic",
    "Input Layer & First\nResNet Block Stochastic",
    "Fully Stochastic",
]

# One entry of seaborn's 'colorblind' palette per model, picked by index.
colors = [sns.color_palette('colorblind')[palette_idx] for palette_idx in (-3, 1, 3, 0)]

def create_subplot(intensity):
    """Draw the relative-NLL point plot for one intensity on the current axes.

    Rows of ``all_results_df_swag`` whose ``intensity`` equals the argument
    are selected, their ``nll`` is divided by the mean ``nll`` of the "All"
    rows, and the result is drawn with one point per entry of ``hue_order``
    (error bars show the standard deviation).

    Parameters
    ----------
    intensity : str or int
        Value compared against the ``intensity`` column, e.g.
        ``"No corruption"`` or ``1``.
    """
    # .copy() gives an independent frame, so adding a column below does not
    # trigger pandas' SettingWithCopyWarning on a filtered selection.
    at_intensity = all_results_df_swag["intensity"] == intensity
    subplot_df = all_results_df_swag[at_intensity].copy()

    # Reference level: mean NLL of the "All" rows, which maps to 1.0.
    mean_fs_nll = subplot_df.loc[subplot_df["model_name"] == "All", "nll"].mean()
    subplot_df["relative_nll"] = subplot_df["nll"]/mean_fs_nll

    # NOTE(review): which of `jitter`, `join`, `errwidth`, `scale` and
    # `zorder` are accepted depends on the installed seaborn version — confirm.
    sns.pointplot(x="model_name", y="relative_nll", hue="model_name",
                  data=subplot_df, hue_order=hue_order, order=hue_order, jitter=False, markers='d', palette=colors,
                  join=False, errwidth=1, scale=0.6, errorbar="std", zorder=2)

    # Dashed reference line at 1.0, drawn in the last colour of `colors`.
    plt.axhline(1.0, linestyle="--", color=colors[-1], linewidth=0.5, zorder=1)
    plt.xticks([], fontsize=7)
    plt.xlabel("")
    plt.xlim([-0.5, 3.5])
    plt.ylabel("")
    plt.ylim((0.75, 1.25))
    plt.yticks(fontsize=8)


# Each entry is (value passed to create_subplot, panel title); panels are
# placed left to right in a single row.
_panels = [
    ("No corruption", "No corruption"),
    (1, "Intensity 1"),
    (3, "Intensity 3"),
    (5, "Intensity 5"),
]
_tick_positions = [0.75, 1.0, 1.25]
_tick_text = ["0.75", "1.0", "1.25"]

for _column, (_intensity, _title) in enumerate(_panels, start=1):
    plt.subplot(1, len(_panels), _column)
    create_subplot(_intensity)
    plt.title(_title, fontsize=7)

    if _column == 1:
        # Only the left-most panel gets the y-axis label and readable ticks.
        # The backslash is escaped because "\d" is an invalid escape
        # sequence in a normal string literal; the text is unchanged.
        plt.ylabel("Relative negative\nlog likelihood ($\\downarrow$)", fontsize=7)
        plt.yticks(_tick_positions, _tick_text, fontsize=7)
    else:
        # Tick text is set at font size 0 on the remaining panels
        # (presumably to hide the text while keeping the tick marks).
        plt.yticks(_tick_positions, _tick_text, fontsize=0)

    if _column < len(_panels):
        # Drop the per-panel legend; the last panel's entries are reused for
        # the shared legend below. Guarded because get_legend() returns None
        # when no legend was drawn.
        _legend = plt.gca().get_legend()
        if _legend is not None:
            _legend.remove()

# Shared legend, anchored relative to the last panel.
handles, _ = plt.gca().get_legend_handles_labels()
plt.legend(handles[:len(hue_order)], labels, fancybox=True, shadow=True, ncol=2,
            fontsize=6, loc="upper center", bbox_to_anchor=(-1.5, -0.1), handletextpad=0.1)
plt.tight_layout(w_pad=0.2)
plt.savefig("SWAG_subnetwork_relative.pdf", bbox_inches='tight')

# SWAG extra configs

In [None]:
# Display names of all configurations, one per line.
all_model_names = [
    "MAP",
    "Output",
    "Output+",
    "Output++",
    "Output+++",
    "Bottleneck",
    "Bottleneck+",
    "Bottleneck++",
    "Input",
    "Input+",
    "Input++",
    "Input+++",
    "Input+Output",
    "All",
]

In [None]:
# Show seaborn's 'colorblind' palette as this cell's output; the plotting
# cells pick their colours from it by index.
sns.color_palette('colorblind')

In [None]:
# Figure-wide styling and the canvas for the wider eight-model SWAG figure.
sns.set_theme(style="white", font="Times New Roman", font_scale=1, context="paper")
plt.rcParams['ytick.left'] = True
plt.figure(figsize=(6.75, 1.75), dpi=500)

# Models to plot, left to right; `colors` and `markers` are index-aligned
# with this list.
hue_order = [
    "MAP",
    "Output", "Output+",
    "Input", "Input+",
    "Input+Output",
    "Bottleneck",
    "All",
]

# Legend text is the model names themselves (same list object).
labels = hue_order

# One entry of seaborn's 'colorblind' palette per model; the "+" variants
# share the colour of their base configuration.
colors = [sns.color_palette('colorblind')[palette_idx]
          for palette_idx in (-3, 1, 1, 3, 3, 4, 5, 0)]

# "Output+" and "Input+" get a star marker, every other model a diamond.
markers = ["*" if model in ("Output+", "Input+") else "d" for model in hue_order]

def create_subplot(intensity):
    """Draw the relative-NLL point plot for one intensity on the current axes.

    Rows of ``all_results_df_swag`` whose ``intensity`` equals the argument
    are selected, their ``nll`` is divided by the mean ``nll`` of the "All"
    rows, and the result is drawn with one point per entry of ``hue_order``
    using the per-model ``markers`` and ``colors`` (error bars show the
    standard deviation).

    Parameters
    ----------
    intensity : str or int
        Value compared against the ``intensity`` column, e.g.
        ``"No corruption"`` or ``1``.
    """
    # .copy() gives an independent frame, so adding a column below does not
    # trigger pandas' SettingWithCopyWarning on a filtered selection.
    at_intensity = all_results_df_swag["intensity"] == intensity
    subplot_df = all_results_df_swag[at_intensity].copy()

    # Reference level: mean NLL of the "All" rows, which maps to 1.0.
    mean_fs_nll = subplot_df.loc[subplot_df["model_name"] == "All", "nll"].mean()
    subplot_df["relative_nll"] = subplot_df["nll"]/mean_fs_nll

    # NOTE(review): which of `jitter`, `join`, `errwidth`, `scale` and
    # `zorder` are accepted depends on the installed seaborn version — confirm.
    sns.pointplot(x="model_name", y="relative_nll", hue="model_name",
                  data=subplot_df, hue_order=hue_order, order=hue_order, jitter=False, markers=markers, palette=colors,
                  join=False, errwidth=1, scale=0.6, errorbar="std", zorder=2)

    # Dashed reference line at 1.0, drawn in the last colour of `colors`.
    plt.axhline(1.0, linestyle="--", color=colors[-1], linewidth=0.5, zorder=1)
    plt.xticks([], fontsize=7)
    plt.xlabel("")
    # Half a unit of padding on either side of the plotted categories.
    plt.xlim([-0.5, len(hue_order)-0.5])
    plt.ylabel("")
    plt.ylim((0.75, 1.25))
    plt.yticks(fontsize=8)


# Each entry is (value passed to create_subplot, panel title); panels are
# placed left to right in a single row.
_panels = [
    ("No corruption", "No corruption"),
    (1, "Intensity 1"),
    (3, "Intensity 3"),
    (5, "Intensity 5"),
]
_tick_positions = [0.75, 1.0, 1.25]
_tick_text = ["0.75", "1.0", "1.25"]

for _column, (_intensity, _title) in enumerate(_panels, start=1):
    plt.subplot(1, len(_panels), _column)
    create_subplot(_intensity)
    plt.title(_title, fontsize=7)

    if _column == 1:
        # Only the left-most panel gets the y-axis label and readable ticks.
        # The backslash is escaped because "\d" is an invalid escape
        # sequence in a normal string literal; the text is unchanged.
        plt.ylabel("Relative negative\nlog likelihood ($\\downarrow$)", fontsize=7)
        plt.yticks(_tick_positions, _tick_text, fontsize=7)
    else:
        # Tick text is set at font size 0 on the remaining panels
        # (presumably to hide the text while keeping the tick marks).
        plt.yticks(_tick_positions, _tick_text, fontsize=0)

    if _column < len(_panels):
        # Drop the per-panel legend; the last panel's entries are reused for
        # the shared legend below. Guarded because get_legend() returns None
        # when no legend was drawn.
        _legend = plt.gca().get_legend()
        if _legend is not None:
            _legend.remove()

# Shared legend, anchored relative to the last panel.
handles, _ = plt.gca().get_legend_handles_labels()
plt.legend(handles[:len(hue_order)], labels, fancybox=True, shadow=True, ncol=9,
            fontsize=6, loc="upper center", bbox_to_anchor=(-1.5, -0.1), handletextpad=0.1, title_fontsize=8, title="Network")
plt.tight_layout(w_pad=0.2)
plt.savefig("SWAG_subnetwork_relative_more.pdf", bbox_inches='tight')

# Laplace approximation

In [None]:
def load_results_into_list(results_dict, results_list, model_name, seed=None):
    """Append one record per (dataset, intensity) of ``results_dict`` to ``results_list``.

    Parameters
    ----------
    results_dict : dict
        Loaded results; ``results_dict["test_results"]`` maps a dataset name
        to a dict with equally indexed ``"acc"``, ``"ece"`` and ``"nll"``
        sequences.
    results_list : list
        List that is extended in place with one dict per entry.
    model_name : str
        Stored in the ``"model_name"`` field of every appended record.
    seed : optional
        Stored in the ``"seed"`` field of every appended record. When
        omitted, the module-level variable ``seed`` is used, which is what
        calls made from inside a ``for seed in ...`` loop rely on.

    Raises
    ------
    NameError
        If ``seed`` is omitted and no module-level ``seed`` exists.
    """
    if seed is None:
        try:
            seed = globals()["seed"]
        except KeyError:
            raise NameError("name 'seed' is not defined; pass seed=... explicitly") from None

    # Read from the argument: the previous body ignored `results_dict` and
    # used whatever global `results` happened to hold.
    for name, corr_dict in results_dict["test_results"].items():
        n_levels = len(corr_dict["acc"])
        for i in range(n_levels):
            results_list.append({
                "dataset": name,
                # A single entry is recorded as "No corruption"; otherwise
                # entries are numbered 1..n_levels.
                "intensity": "No corruption" if n_levels == 1 else 1 + i,
                "acc": 100*corr_dict["acc"][i],
                "ece": 100*corr_dict["ece"][i],
                "nll": corr_dict["nll"][i],
                "model_name": model_name,
                "seed": seed,
            })

                    
# Collect the Laplace results. Files without the "_val" suffix are recorded
# with an "_MML" model-name suffix, files with it with a "_CV" suffix.
all_results_df_list_laplace = []

laplace_model_codes = ['00000000', '10000001', '11111111', '00000001']
laplace_model_names = ['MAP', 'Input/Output', 'All', 'Output']
# NOTE(review): this rebinds `filepath`, which the SWAG loading cell also
# reads — re-running that cell afterwards would look in this directory.
filepath = "sublayer_laplace_results"

# NOTE(security): unpickling can execute arbitrary code; only load result
# files from a trusted source.
# The loop variable must stay named `seed`: load_results_into_list is called
# without a seed argument and reads it from module scope.
for seed in range(10):
    for model_code, model_name in zip(laplace_model_codes, laplace_model_names):
        try:
            # Context managers close each file handle after loading.
            with open(f"{filepath}/{model_code}_s{seed}.pkl", "rb") as results_file:
                results = CPU_Unpickler(results_file).load()
            load_results_into_list(results, all_results_df_list_laplace, f"{model_name}_MML")
        except FileNotFoundError:
            print(f"Skipping one of {model_name} seed {seed} MML Tuning")

        # No "_val" file is read for the MAP model.
        if model_name == "MAP":
            continue

        try:
            with open(f"{filepath}/{model_code}_s{seed}_val.pkl", "rb") as results_file:
                results = CPU_Unpickler(results_file).load()
            load_results_into_list(results, all_results_df_list_laplace, f"{model_name}_CV")
        except FileNotFoundError:
            print(f"Skipping one of {model_name} seed {seed} CV Tuning")

    try:
        with open(f"{filepath}/subnetwork_4096_s{seed}.pkl", "rb") as results_file:
            results = CPU_Unpickler(results_file).load()
        load_results_into_list(results, all_results_df_list_laplace, "Subnetwork4096_MML")
    except FileNotFoundError:
        print(f"Skipping one of subnetwork 4096 seed {seed} MML Tuning")

    try:
        with open(f"{filepath}/subnetwork_4096_s{seed}_val.pkl", "rb") as results_file:
            results = CPU_Unpickler(results_file).load()
        load_results_into_list(results, all_results_df_list_laplace, "Subnetwork4096_CV")
    except FileNotFoundError:
        print(f"Skipping one of subnetwork 4096 seed {seed} CV")

all_results_df_laplace = pd.DataFrame(all_results_df_list_laplace)
# numeric_only=True: the frame also holds the string column "dataset";
# recent pandas raises when asked to average it instead of silently
# dropping it.
all_results_df_laplace = (
    all_results_df_laplace
    .groupby(["model_name", "seed", "intensity"])
    .mean(numeric_only=True)
    .reset_index()
)

In [None]:
# "No corruption" followed by the intensity levels 1 to 5.
plot_order = ["No corruption", *range(1, 6)]

# Figure-wide styling and the canvas for the four-panel Laplace figure.
sns.set_theme(style="white", font="Times New Roman", font_scale=1, context="paper")
plt.rcParams['ytick.left'] = True
plt.figure(figsize=(3.5, 1.75), dpi=500)

# Models to plot, left to right; `labels`, `markers` and `colors` are
# index-aligned with this list.
hue_order = [
    'MAP_MML',
    'Subnetwork4096_CV',
    'Input/Output_CV',
    'All_CV',
]
labels = [
    "Deterministic Network (MAP)",
    "SWAG Subnetwork Stochastic",
    "Input & Output Layers Stochastic",
    "Fully Stochastic",
]

# Diamond marker for each of the four models.
markers = ["d"] * 4
# One entry of seaborn's 'colorblind' palette per model, picked by index.
colors = [sns.color_palette('colorblind')[palette_idx] for palette_idx in (-3, 3, 1, 0)]

def create_subplot(intensity):
    """Draw the relative-NLL point plot for one intensity on the current axes.

    Rows of ``all_results_df_laplace`` whose ``intensity`` equals the
    argument are selected, their ``nll`` is divided by the mean ``nll`` of
    the "All_CV" rows, and the result is drawn with one point per entry of
    ``hue_order`` (error bars show the standard deviation).

    Parameters
    ----------
    intensity : str or int
        Value compared against the ``intensity`` column, e.g.
        ``"No corruption"`` or ``1``.
    """
    # .copy() gives an independent frame, so adding a column below does not
    # trigger pandas' SettingWithCopyWarning on a filtered selection.
    at_intensity = all_results_df_laplace["intensity"] == intensity
    subplot_df = all_results_df_laplace[at_intensity].copy()

    # Reference level: mean NLL of the "All_CV" rows, which maps to 1.0.
    mean_fs_nll = subplot_df.loc[subplot_df["model_name"] == "All_CV", "nll"].mean()
    subplot_df["relative_nll"] = subplot_df["nll"]/mean_fs_nll

    # NOTE(review): which of `jitter`, `join`, `errwidth`, `scale` and
    # `zorder` are accepted depends on the installed seaborn version — confirm.
    sns.pointplot(x="model_name", y="relative_nll", hue="model_name",
                  data=subplot_df, hue_order=hue_order, order=hue_order, jitter=False, markers='d', palette=colors,
                  join=False, errwidth=1, scale=0.6, errorbar="std", zorder=2)

    # Dashed reference line at 1.0, drawn in the last colour of `colors`.
    plt.axhline(1.0, linestyle="--", color=colors[-1], linewidth=0.5, zorder=1)
    plt.xticks([], fontsize=7)
    plt.xlabel("")
    # NOTE(review): `hue_order` has four entries in this cell, so an upper
    # limit of 4.5 leaves one empty slot on the right — confirm this is
    # intended (3.5 would fit the four categories exactly).
    plt.xlim([-0.5, 4.5])
    plt.ylabel("")

    plt.ylim((0.8, 1.75))
    plt.yticks(fontsize=8)


# Each entry is (value passed to create_subplot, panel title); panels are
# placed left to right in a single row.
_panels = [
    ("No corruption", "No corruption"),
    (1, "Intensity 1"),
    (3, "Intensity 3"),
    (5, "Intensity 5"),
]

for _column, (_intensity, _title) in enumerate(_panels, start=1):
    plt.subplot(1, len(_panels), _column)
    create_subplot(_intensity)

    if _column == 1:
        # Only the left-most panel gets the y-axis label; its ticks are left
        # as create_subplot configured them.
        # The backslash is escaped because "\d" is an invalid escape
        # sequence in a normal string literal; the text is unchanged.
        plt.ylabel("Relative negative\nlog likelihood ($\\downarrow$)", fontsize=7)
    else:
        # Remaining panels keep tick marks at fixed positions with empty
        # tick text.
        plt.yticks([1, 1.25, 1.5, 1.75], ["", "", "", ""])

    plt.title(_title, fontsize=7)

    if _column < len(_panels):
        # Drop the per-panel legend; the last panel's entries are reused for
        # the shared legend below. Guarded because get_legend() returns None
        # when no legend was drawn.
        _legend = plt.gca().get_legend()
        if _legend is not None:
            _legend.remove()


# plt.tight_layout(w_pad=0.2)
# plt.suptitle("CIFAR10 - Laplace Approximation", fontsize=8)
# Shared legend, anchored relative to the last panel.
handles, _ = plt.gca().get_legend_handles_labels()
plt.legend(handles[:len(hue_order)], labels, fancybox=True, shadow=True, ncol=2,
            fontsize=6, loc="upper center", bbox_to_anchor=(-1.5, -0.1), handletextpad=0.1)
plt.savefig("laplace_with_subnetwork_relative.pdf", bbox_inches='tight')