# Analyze the Posterior Samples of the Last Layer in Simple Architectures

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import sys
import pandas as pd
sys.path.append('../')
from experiments.fcn_bnns.utils.analysis_utils import *
from module_sandbox.utils import (  # noqa: E402
    mse,
)
%load_ext autoreload
%autoreload 2

In [None]:
# Select the timestamped results folder (one sweep) to analyze.
# datet = "2023-12-22-16-22-58" # Architecture 2
# datet = "2024-01-18-08-29-03" # Architecture 16-16-2
datet = "2024-01-19-12-42-31" # Huge prior 1000
# NOTE(review): the comment above says prior 1000 but the prior field in the
# exp_name pattern below is 100 — confirm which run is intended.
RESULTS_DIR = f'../results/fcn_bnns/{datet}'
CONFIG_PATH = f'{RESULTS_DIR}/config.yaml'
DATA_PATH = '../data'
replication = 1
exp_names = get_exp_names(path=f"{RESULTS_DIR}/")
# exp_name = f'airfoil.data|tanh|2|12|8000|False|NUTS_large|{replication}|1|Normal' # 2
# exp_name = f'airfoil.data|tanh|16-16-2|10|8000|False|NUTS_large|{replication}|1|Normal' # 16-16-2
exp_name = f'airfoil.data|tanh|16-16-2|12|8000|False|NUTS_large|{replication}|100|Normal' # huge prior
# exp_name above is a substring pattern; resolve it to the first full experiment
# name containing it, and fail loudly (instead of a bare IndexError) if none does.
matching_exp_names = [ename for ename in exp_names if exp_name in ename]
if not matching_exp_names:
    raise ValueError(
        f"No experiment in {RESULTS_DIR}/ contains the pattern {exp_name!r}"
    )
exp_name = matching_exp_names[0]
print(exp_name)

In [None]:
exp_info = extract_exp_info(exp_name)
config = load_config(CONFIG_PATH)
n_chains = int(exp_info['n_chains'])
n_samples = int(exp_info['n_samples'])

# Data: cap the validation split at 1000 rows to bound prediction cost.
X_train, Y_train = load_data(exp_info, splittype='train', data_path=DATA_PATH)
X_val, Y_val = load_data(exp_info, splittype='val', data_path=DATA_PATH)
val_threshold = min(1000, X_val.shape[0])
X_val = X_val[:val_threshold, :]
Y_val = Y_val[:val_threshold, :]

# Classical baselines; the linear-regression RMSE is the bar a chain must beat.
linear_regr, rf_regr = fit_baselines(X_train, Y_train)
mse_linear, mse_rf = evaluate_baselines(X_val, Y_val, linear_regr, rf_regr)
res_dict = {
    'rmse_linear': np.sqrt(mse_linear),
    'rmse_rf': np.sqrt(mse_rf),
}

# Posterior samples and posterior-predictive values on X_val, both split by
# chain (preds_chain_dim) and stacked over chains (preds).
posterior_samples, posterior_samples_raw = load_samples(exp_name, f'../results/fcn_bnns/{datet}')
model = load_model(exp_name, f'../results/fcn_bnns/{datet}')
preds_chain_dim, preds = get_posterior_predictive(
    model, posterior_samples_raw, X_val, exp_info['n_chains']
)

# Validation RMSE of each chain.
rmse_per_chain = {}
for i in range(preds_chain_dim.shape[0]):
    rmse_per_chain[f'chain_{i}'] = np.sqrt(mse(preds_chain_dim[i], Y_val)[0])
rmse_table = pd.DataFrame(rmse_per_chain, index=['RMSE']).T

# A chain is "bad" if its RMSE exceeds the linear baseline's RMSE. Rows of
# rmse_table are chain_0, chain_1, ... in order, so row positions are chain ids.
bad_chains = np.flatnonzero(
    rmse_table['RMSE'].to_numpy() > res_dict['rmse_linear']
).tolist()
good_chains = [i for i in range(n_chains) if i not in bad_chains]
res_dict['n_bad_chains'] = len(bad_chains)
res_dict['n_good_chains'] = len(good_chains)

if not good_chains:
    # No chain beats the linear baseline: there is nothing to pool.
    res_dict['rmse_good_chains'] = np.nan
    res_dict['rmse_good_chains_100'] = np.nan
    # NOTE(review): the acc_90hpdi metrics are only ever set here (to NaN) and
    # never computed in the else-branch — confirm whether they are still needed.
    res_dict['acc_90hpdi'] = np.nan
    res_dict['acc_90hpdi_100'] = np.nan
else:
    # Rows of `preds` are indexed assuming chain i occupies the contiguous
    # block [i * n_samples, (i + 1) * n_samples).
    # Cap the "first 100 samples" window at n_samples so that short chains do
    # not pull rows belonging to the next chain.
    n_head = min(100, n_samples)
    good_chains_pred_indices = np.concatenate(
        [np.arange(n_samples) + (n_samples * i) for i in good_chains]
    )
    good_chains_pred_indices_100 = np.concatenate(
        [np.arange(n_head) + (n_samples * i) for i in good_chains]
    )
    # RMSE of the good chains, using all samples and the first n_head samples.
    res_dict['rmse_good_chains'] = np.sqrt(
        mse(preds[good_chains_pred_indices, :], Y_val)[0]
    )
    res_dict['rmse_good_chains_100'] = np.sqrt(
        mse(preds[good_chains_pred_indices_100, :], Y_val)[0]
    )

res_dict

In [None]:
# Recompute the per-chain validation RMSE and the good/bad chain split
# (a chain is bad if its RMSE exceeds the linear-baseline RMSE).
truncate_samples = config["n_samples"]
rmse_per_chain = {
    f'chain_{i}': np.sqrt(mse(preds_chain_dim[i], Y_val)[0])
    for i in range(preds_chain_dim.shape[0])
}
rmse_table = pd.DataFrame(rmse_per_chain, index=['RMSE']).T
# Rows are chain_0, chain_1, ... in order, so row positions are chain ids.
bad_chains = np.flatnonzero(
    rmse_table['RMSE'].to_numpy() > res_dict["rmse_linear"]
).tolist()
good_chains = [i for i in range(n_chains) if i not in bad_chains]
# NOTE(review): this overrides the n_samples parsed from exp_info in the
# previous cell with the config value — confirm the two agree.
n_samples = config['n_samples']
# np.concatenate raises ValueError on an empty list, so handle the case where
# every chain is worse than the linear baseline.
if good_chains:
    good_chains_pred_indices = np.concatenate(
        [np.arange(truncate_samples) + (n_samples * i) for i in good_chains]
    )
else:
    good_chains_pred_indices = np.array([], dtype=int)
rmse_table = rmse_table.sort_values(by='RMSE', ascending=True)
# All entries of the RMSE table that exceed the linear-baseline RMSE
# (np.sqrt(mse_linear)) get a red background.
def color_cells(x, threshold=None):
    """Return a CSS style that highlights ``x`` in red when it exceeds ``threshold``.

    Intended as an element-wise ``Styler.map`` callback on the RMSE table.

    Args:
        x: A single cell value (an RMSE).
        threshold: Highlight cutoff. Defaults to ``res_dict["rmse_linear"]``,
            the linear-regression baseline RMSE computed earlier in the notebook.

    Returns:
        ``'background-color: red'`` if ``x > threshold``, otherwise ``''``.
    """
    if threshold is None:
        threshold = res_dict["rmse_linear"]
    return 'background-color: red' if x > threshold else ''
# Render the sorted RMSE table: red cells for chains worse than the linear
# baseline, values shown with three decimals.
rmse_table = rmse_table.style.map(color_cells).format('{:.3f}')
rmse_table

In [None]:
# List every sampled parameter together with the shape of its sample array.
for param_name, param_samples in posterior_samples.items():
    print(param_name)
    print(param_samples.shape)

We are now interested in the last layer, specifically in the posterior samples of the weights and bias that feed the $\mu$ output of the last layer.


In [None]:
# Names of the output-layer parameters depend on the architecture:
#   architecture 2       -> "W2" / "b2"
#   architecture 16-16-2 -> "W4" / "b4"   (used here)
out_weight_key, out_bias_key = "W4", "b4"

# Keep output unit 0 (the mu output) and only the chains that beat the
# linear baseline.
mu_weights_all_chains = posterior_samples[out_weight_key][:, :, :, 0]
last_layer_mu_weight_samples = mu_weights_all_chains[good_chains, ...]
print(last_layer_mu_weight_samples.shape)

mu_bias_all_chains = posterior_samples[out_bias_key][:, :, 0]
last_layer_mu_bias_samples = mu_bias_all_chains[good_chains, ...]
print(last_layer_mu_bias_samples.shape)

In [None]:
# Average each output-layer mu weight over axis 1 (presumably the sample axis;
# axis 0 indexes the good chains), giving one mean vector per chain.
per_chain_weight_means = last_layer_mu_weight_samples.mean(axis=1)
per_chain_weight_means

In [None]:
# Trace of the first output-layer mu weight for the first good chain.
# Row 0 of last_layer_mu_weight_samples is chain good_chains[0] (rows were
# filtered by good_chains), so label it with that original chain id rather
# than a hard-coded 0.
trace_chain_id = good_chains[0]
sns.lineplot(
    x=np.arange(last_layer_mu_weight_samples.shape[1]),
    y=last_layer_mu_weight_samples[0, :, 0],
    label=f'chain {trace_chain_id}',
)

In [None]:
# Biases
for chain in range(last_layer_mu_bias_samples.shape[0]):
    sns.displot(last_layer_mu_bias_samples[chain, :])
    sns.rugplot(last_layer_mu_bias_samples[chain, :])
    # add title
    plt.title(f"Chain {chain}")
    # plt.show()

In [None]:
# flatten the first two dimensions
last_layer_mu_bias_samples_flat = last_layer_mu_bias_samples.flatten()
chain_indices = np.concatenate(
    [np.repeat(i, last_layer_mu_bias_samples.shape[1]) for i in range(len(good_chains))]
)
sns.displot(
    x=last_layer_mu_bias_samples_flat,
    hue=chain_indices,
    rug=False,
    kde=True,
    bins =60,
    stat="density",
    alpha=0.6,
)

In [None]:
# Pooled scatter + KDE of the first two output-layer mu weights across good
# chains, using at most the first 2000 samples per chain. The cap is bounded
# by the available samples so that short chains do not break the reshape.
n_plot = min(2000, last_layer_mu_weight_samples.shape[1])
last_layer_mu_weight_samples_flat = last_layer_mu_weight_samples[:, :n_plot, :].reshape(
    -1, last_layer_mu_weight_samples.shape[2]
)
# Row-major reshape keeps each chain's samples contiguous; label them with the
# original chain ids.
chain_indices = np.repeat(good_chains, n_plot)
fig = sns.jointplot(
    x=last_layer_mu_weight_samples_flat[:, 0],
    y=last_layer_mu_weight_samples_flat[:, 1],
    alpha=0.02,
    kind="scatter",
    hue=chain_indices,
    palette="viridis",
)
# plot_joint forwards the grid's hue and palette to kdeplot, which then colours
# each chain from that palette (a `color=` argument is ignored when hue is set).
fig.plot_joint(sns.kdeplot, zorder=0, levels=6)
fig.set_axis_labels("Output Layer $w_1$", "Output Layer $w_2$")
fig.ax_joint.set_xlim(-5, 5)
fig.ax_joint.set_ylim(-5, 5)
# Hide the per-chain legend on the joint axes explicitly (not via plt's
# current-axes state).
fig.ax_joint.legend([], [], frameon=False)

In [None]:
# Per good chain: bivariate view of the first two output-layer mu weights —
# scatter, red KDE overlay on the joint axes, and red rug marginals.
# Row k of last_layer_mu_weight_samples is chain good_chains[k], so each
# figure is titled with the original chain id.
for chain_id, chain_weights in zip(good_chains, last_layer_mu_weight_samples):
    g = sns.jointplot(
        x=chain_weights[:, 0],
        y=chain_weights[:, 1],
        alpha=0.1,
    )
    g.plot_joint(sns.kdeplot, color="r", zorder=0, levels=6)
    g.plot_marginals(sns.rugplot, color="r", height=-.15, clip_on=False)
    g.ax_marg_x.set_title(f"Chain {chain_id}")
    plt.show()