# Analyze the slopes of the chainwise posterior samples

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import sys
import pandas as pd
sys.path.append('../')
from experiments.fcn_bnns.utils.analysis_utils import *
from module_sandbox.utils import (  # noqa: E402
    mse,
)
%load_ext autoreload
%autoreload 2

In [None]:
# Timestamp of the results directory to analyse (toggle between architectures).
datet = "2023-12-22-16-22-58" # Architecture shallow
# datet = "2024-01-16-13-57-52" # Architecture deep
CONFIG_PATH = f'../results/fcn_bnns/{datet}/config.yaml'
DATA_PATH = '../data'
replication = 1
exp_names = get_exp_names(path=f"../results/fcn_bnns/{datet}/")
# Substring identifying the experiment of interest among `exp_names`.
exp_pattern = f'airfoil.data|tanh|16-16|12|8000|False|NUTS_large|{replication}|1|Normal' # shallow
# exp_pattern = f'airfoil.data|tanh|8-8-8-8-8-8|12|8000|False|NUTS_large|{replication}|1|Normal' # deep
matching_exp_names = [ename for ename in exp_names if exp_pattern in ename]
if not matching_exp_names:
    # Fail with an actionable message instead of a bare "list index out of
    # range" when the pattern and the results directory do not belong together.
    raise IndexError(
        f"No experiment in '../results/fcn_bnns/{datet}/' matches "
        f"{exp_pattern!r}"
    )
# Use the first match, as before.
exp_name = matching_exp_names[0]
print(exp_name)

In [None]:
# Load the experiment, score the baselines and classify each chain as "good"
# or "bad" by comparing its validation RMSE with the linear baseline.
exp_info = extract_exp_info(exp_name)
config = load_config(CONFIG_PATH)
n_chains = int(exp_info['n_chains'])
n_samples = int(exp_info['n_samples'])
X_train, Y_train = load_data(exp_info, splittype='train', data_path=DATA_PATH)
X_val, Y_val = load_data(exp_info, splittype='val', data_path=DATA_PATH)
# Evaluate on at most the first 1000 validation rows.
val_threshold = min(1000, X_val.shape[0])
X_val = X_val[:val_threshold, :]
Y_val = Y_val[:val_threshold, :]

# Baselines: linear regression and random forest.
linear_regr, rf_regr = fit_baselines(X_train, Y_train)
mse_linear, mse_rf = evaluate_baselines(X_val, Y_val, linear_regr, rf_regr)
res_dict = {
    'rmse_linear': np.sqrt(mse_linear),
    'rmse_rf': np.sqrt(mse_rf),
}

# Posterior samples, model and posterior predictive on the validation set.
posterior_samples, posterior_samples_raw = load_samples(exp_name, f'../results/fcn_bnns/{datet}')
model = load_model(exp_name, f'../results/fcn_bnns/{datet}')
preds_chain_dim, preds = get_posterior_predictive(
    model, posterior_samples_raw, X_val, exp_info['n_chains']
)

# Per-chain RMSE; the first axis of `preds_chain_dim` is iterated as the
# chain axis.
rmse_per_chain = {
    f'chain_{i}': np.sqrt(mse(preds_chain_dim[i], Y_val)[0])
    for i in range(preds_chain_dim.shape[0])
}
rmse_table = pd.DataFrame(rmse_per_chain, index=['RMSE']).T

# A chain is "bad" when it is worse than the linear baseline. The chain number
# is recovered from the 'chain_<i>' row label.
# NOTE(review): a chain with NaN RMSE compares False here and is therefore
# kept as "good" - confirm this is intended.
bad_chains = rmse_table[rmse_table['RMSE'] > res_dict['rmse_linear']].index
bad_chains = bad_chains.str.split('_').str[1].astype(int).values.tolist()
good_chains = [i for i in range(n_chains) if i not in bad_chains]
res_dict['n_bad_chains'] = len(bad_chains)
res_dict['n_good_chains'] = len(good_chains)

if good_chains:
    # Rows of `preds` are addressed as chain-major blocks of `n_samples` rows.
    # assumes `preds` is the chain-wise concatenation - TODO confirm
    good_chains_pred_indices = np.concatenate(
        [np.arange(n_samples) + (n_samples * i) for i in good_chains]
    )
    # First (up to) 100 samples of every good chain. Clamped so that chains
    # shorter than 100 samples do not spill into the next chain's block.
    n_head = min(100, n_samples)
    good_chains_pred_indices_100 = np.concatenate(
        [np.arange(n_head) + (n_samples * i) for i in good_chains]
    )
    res_dict['rmse_good_chains'] = np.sqrt(
        mse(preds[good_chains_pred_indices, :], Y_val)[0]
    )
    res_dict['rmse_good_chains_100'] = np.sqrt(
        mse(preds[good_chains_pred_indices_100, :], Y_val)[0]
    )
else:
    # No usable chain: keep the index arrays defined (empty) and report NaN.
    good_chains_pred_indices = np.array([], dtype=int)
    good_chains_pred_indices_100 = np.array([], dtype=int)
    res_dict['rmse_good_chains'] = np.nan
    res_dict['rmse_good_chains_100'] = np.nan
    res_dict['acc_90hpdi'] = np.nan
    res_dict['acc_90hpdi_100'] = np.nan

res_dict

In [None]:
# Rebuild the per-chain RMSE table and the good/bad chain split, using the
# sample count from the config file.
truncate_samples = config["n_samples"]
n_samples = config['n_samples']
rmse_per_chain = {
    f'chain_{i}': np.sqrt(mse(preds_chain_dim[i], Y_val)[0])
    for i in range(preds_chain_dim.shape[0])
}
rmse_table = pd.DataFrame(rmse_per_chain, index=['RMSE']).T
# Chains worse than the linear baseline are "bad"; the chain number is parsed
# from the 'chain_<i>' row label.
bad_chains = rmse_table[rmse_table['RMSE'] > res_dict["rmse_linear"]].index
bad_chains = bad_chains.str.split('_').str[1].astype(int).values.tolist()
good_chains = [i for i in range(n_chains) if i not in bad_chains]
if good_chains:
    good_chains_pred_indices = np.concatenate(
        [np.arange(truncate_samples) + (n_samples * i) for i in good_chains]
    )
else:
    # np.concatenate raises ValueError on an empty list; fall back to an
    # empty index array when every chain is worse than the baseline.
    good_chains_pred_indices = np.array([], dtype=int)
# Best chains first.
rmse_table = rmse_table.sort_values(by='RMSE', ascending=True)
# Every entry of the table that exceeds the linear-baseline RMSE gets a red
# background.
def color_cells(x, threshold=None):
    """Return the CSS style for a single table cell.

    Args:
        x: The cell value (an RMSE).
        threshold: Value above which the cell is highlighted. When omitted,
            ``res_dict["rmse_linear"]`` is looked up at call time, so the
            function can still be passed to ``Styler.map`` without arguments.

    Returns:
        ``'background-color: red'`` if ``x`` is strictly greater than the
        threshold, otherwise the empty string (no styling).
    """
    if threshold is None:
        threshold = res_dict["rmse_linear"]
    return 'background-color: red' if x > threshold else ''
# Apply the per-cell highlighting and show three decimals; from here on
# `rmse_table` is a Styler, which the notebook renders as the cell output.
rmse_table = rmse_table.style.map(color_cells).format('{:.3f}')
rmse_table

In [None]:
# Print every parameter name followed by the shape of its posterior samples.
for parameter, samples in posterior_samples.items():
    print(parameter)
    print(samples.shape)

In [None]:
# Fit a slope to the posterior samples of every parameter, restricted to the
# good chains, and collect the results separately for biases and weights.
# `fit_slope` is applied along axis 1 of the chain-selected array.
# presumably axis 0 is the chain axis and axis 1 the sample axis - TODO confirm
bias_slopes = {}
weight_slopes = {}
for parameter, samples in posterior_samples.items():
    # Names containing 'b' are treated as biases, otherwise names containing
    # 'W' as weights; anything else is skipped.
    if 'b' in parameter:
        slope_store = bias_slopes
    elif 'W' in parameter:
        slope_store = weight_slopes
    else:
        continue
    slope_store[parameter] = jnp.apply_along_axis(
        fit_slope, 1, samples[good_chains, ...]
    ).flatten()

# Mean and standard deviation of the fitted slopes per parameter.
bias_slopes_mean = {
    name: np.mean(slopes).item() for name, slopes in bias_slopes.items()
}
bias_slopes_std = {
    name: np.std(slopes).item() for name, slopes in bias_slopes.items()
}
weight_slopes_mean = {
    name: np.mean(slopes).item() for name, slopes in weight_slopes.items()
}
weight_slopes_std = {
    name: np.std(slopes).item() for name, slopes in weight_slopes.items()
}

fig = plt.figure(figsize=(10, 5))
# One line per parameter kind (weights first, then biases), with the standard
# deviation of the slopes drawn as error bars, over the layer index.
layers = np.arange(1, len(weight_slopes)+1)
weight_keys = [f'W{i}' for i in layers]
bias_keys = [f'b{i}' for i in layers]
line_specs = (
    (weight_keys, weight_slopes_mean, weight_slopes_std, 'Weights', '#06238f'),
    (bias_keys, bias_slopes_mean, bias_slopes_std, 'Biases', '#2e59f2'),
)
for keys, means, stds, line_label, line_color in line_specs:
    layer_means = [means[key] for key in keys]
    layer_stds = [stds[key] for key in keys]
    sns.lineplot(
        x=layers,
        y=layer_means,
        marker='o',
        label=line_label,
        color=line_color,
    )
    plt.errorbar(
        x=layers,
        y=layer_means,
        yerr=layer_stds,
        fmt='none',
        ecolor=line_color,
        capsize=5,
    )
plt.xlabel('Layer')
plt.ylabel('Fitted Slopes')
# One tick per layer.
plt.xticks(layers)
plt.title('Fitted Slopes on the posterior samples of the Biases and Weights')




In [None]:
# Side-by-side box plots of the fitted slopes per layer: weights on the left,
# biases on the right.
fig, axs = plt.subplots(1, 2, figsize=(20, 5))

weight_labels = [f'W{i}' for i in layers]
weight_data = [weight_slopes[label] for label in weight_labels]
bias_labels = [f'b{i}' for i in layers]
bias_data = [bias_slopes[label] for label in bias_labels]

panels = (
    (axs[0], weight_data, weight_labels, '#06238f', 'Weights',
     'Fitted Slopes on the posterior samples of the Weights'),
    (axs[1], bias_data, bias_labels, '#2e59f2', 'Biases',
     'Fitted Slopes on the posterior samples of the Biases'),
)
for ax, data, labels, box_color, x_label, title in panels:
    sns.boxplot(data=data, ax=ax, color=box_color, width=0.3)
    # Dashed reference line at slope 0 and a grey band covering [-1, 1].
    ax.axhline(y=0, color='black', linestyle='--')
    ax.axhspan(-1, 1, alpha=0.2, color='grey')
    # Pin the tick positions before relabelling them: the boxes sit at
    # 0..n-1, and set_xticklabels should only follow fixed tick positions.
    ax.set_xticks(range(len(labels)))
    ax.set_xticklabels(labels)
    ax.set_xlabel(x_label)
    # The two panels do not share a y-axis, so both get a y-label.
    ax.set_ylabel('Fitted Slopes')
    ax.set_title(title)

plt.tight_layout()
plt.show()
