# Analyze the posterior samples and their induced models from selected experiments

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from IPython.display import display
import jax.numpy as jnp
import seaborn as sns
import sys
import pandas as pd
from numpyro.diagnostics import hpdi
sys.path.append('../')
from experiments.fcn_bnns.utils.analysis_utils import *
from src.utils import (  # noqa: E402
    mse,
)
from src.visualization.posterior_predictive import (  # noqa: E402
    pp_interchain_means,
    visualize_pp_chain_means,
)
from src.utils import flatten_chain_dimension
from experiments.fcn_bnns.utils.ui_utils import (  # noqa: E402
    calculate_diagnostics,
    plot_sample_paths,
    visualize_ess,
    visualize_pp_rhat,
    visualize_rhat,
)
from src.diagnostics.gelman import split_chain_r_hat, gelman_split_r_hat  # noqa: E402
%load_ext autoreload
%autoreload 2

In [None]:
# Timestamp (folder name) of the results run to analyze; placeholder value,
# replace with an actual run folder.
datet = "YYYY-MM-DD-00-00-00"
CONFIG_PATH = f'../results/fcn_bnns/{datet}/config.yaml'
DATA_PATH = '../data'
# Replication id that is inserted into the experiment-name pattern below.
replication = 1
# presumably provided by the analysis_utils star import — verify
exp_names = get_exp_names(path=f"../results/fcn_bnns/{datet}/")

In [None]:
DATASET = 'airfoil'
# Pattern describing the wanted experiment; the first name in exp_names that
# contains it as a substring is selected.
exp_name = f'{DATASET}.data|tanh|2|2|100|False|NUTS_tiny|{replication}|1|Normal'
# NOTE(review): raises IndexError when no experiment name matches the pattern.
exp_name = [ename for ename in exp_names if exp_name in ename][0]
print(exp_name)

In [None]:
# Load one experiment, fit reference baselines, and collect validation
# metrics of the posterior samples in ``res_dict``.
exp_info = extract_exp_info(exp_name)
config = load_config(CONFIG_PATH)
n_chains = int(exp_info['n_chains'])
n_samples = int(exp_info['n_samples'])
X_train, Y_train = load_data(exp_info, splittype='train', data_path=DATA_PATH)
X_val, Y_val = load_data(exp_info, splittype='val', data_path=DATA_PATH)
# Evaluate on at most the first 1000 validation rows.
val_threshold = min(1000, X_val.shape[0])
X_val = X_val[:val_threshold, :]
Y_val = Y_val[:val_threshold, :]

# Baselines; the linear model's RMSE is the threshold for "bad" chains.
linear_regr, rf_regr = fit_baselines(X_train, Y_train)
mse_linear, mse_rf = evaluate_baselines(X_val, Y_val, linear_regr, rf_regr)
res_dict = {}
res_dict['rmse_linear'] = np.sqrt(mse_linear)
res_dict['rmse_rf'] = np.sqrt(mse_rf)

# NOTE(review): the warmup length 10000 is hard-coded — confirm it matches
# the sampler configuration of this run.
discard_warmup = 10000 if config["keep_warmup"] else 0
posterior_samples, posterior_samples_raw = load_samples(
    exp_name, f'../results/fcn_bnns/{datet}', discard_warmup=discard_warmup
)
model = load_model(exp_name, f'../results/fcn_bnns/{datet}')
preds_chain_dim, preds = get_posterior_predictive(
    model, posterior_samples_raw, X_val, exp_info['n_chains']
)

# RMSE of every chain; chains worse than the linear baseline are "bad".
rmse_per_chain = {}
for i in range(preds_chain_dim.shape[0]):
    rmse_per_chain[f'chain_{i}'] = np.sqrt(mse(preds_chain_dim[i], Y_val)[0])
rmse_table = pd.DataFrame(rmse_per_chain, index=['RMSE']).T
bad_chains = rmse_table[rmse_table['RMSE'] > res_dict['rmse_linear']].index
# Recover the integer chain id from the 'chain_<i>' row labels.
bad_chains = bad_chains.str.split('_').str[1].astype(int).values
bad_chains = bad_chains.tolist()
good_chains = [i for i in range(n_chains) if i not in bad_chains]

res_dict['n_bad_chains'] = len(bad_chains)
res_dict['n_good_chains'] = len(good_chains)
if not good_chains:
    res_dict['rmse_good_chains'] = np.nan
    res_dict['rmse_good_chains_100'] = np.nan
    res_dict['acc_90hpdi'] = np.nan
    res_dict['acc_90hpdi_100'] = np.nan
else:
    # Row indices into the flattened ``preds`` for all samples of the good
    # chains, and for only the first (up to) 100 samples of each good chain.
    # Capping at n_samples keeps the short selection inside its own chain.
    _n_head = min(100, n_samples)
    good_chains_pred_indices = np.concatenate(
        [np.arange(n_samples) + (n_samples * i) for i in good_chains]
    )
    good_chains_pred_indices_100 = np.concatenate(
        [np.arange(_n_head) + (n_samples * i) for i in good_chains]
    )
    # RMSE
    res_dict['rmse_good_chains'] = np.sqrt(
        mse(preds[good_chains_pred_indices, :], Y_val)[0]
    )
    res_dict['rmse_good_chains_100'] = np.sqrt(
        mse(preds[good_chains_pred_indices_100, :], Y_val)[0]
    )
    # Empirical coverage of the 90% HPDI, so that res_dict carries the same
    # keys in both branches (same computation as the calibration cells).
    _y_flat = Y_val.squeeze()
    for _suffix, _idx in (
        ('', good_chains_pred_indices),
        ('_100', good_chains_pred_indices_100),
    ):
        _hpdi_90 = hpdi(preds[_idx, :], 0.9)
        res_dict[f'acc_90hpdi{_suffix}'] = float(
            jnp.mean((_hpdi_90[0, :] <= _y_flat) & (_hpdi_90[1, :] >= _y_flat))
        )
res_dict

# Predictive performance on the validation set

In [None]:
# Number of samples per chain used by the analyses below; defaults to the
# configured sample count. Swap in the alternative to analyze a shorter prefix.
truncate_samples = config["n_samples"]
# truncate_samples = 1000

In [None]:
# Per-chain validation RMSE table; chains that are worse than the linear
# baseline are treated as "bad" and highlighted in red.
rmse_per_chain = {}
for i in range(preds_chain_dim.shape[0]):
    rmse_per_chain[f'chain_{i}'] = np.sqrt(mse(preds_chain_dim[i], Y_val)[0])
rmse_table = pd.DataFrame(rmse_per_chain, index=['RMSE']).T
bad_chains = rmse_table[rmse_table['RMSE'] > res_dict["rmse_linear"]].index
# Recover the integer chain id from the 'chain_<i>' row labels.
bad_chains = bad_chains.str.split('_').str[1].astype(int).values
bad_chains = bad_chains.tolist()
good_chains = [i for i in range(n_chains) if i not in bad_chains]
# NOTE(review): n_samples is re-read from the config here, whereas it was
# taken from exp_info earlier — confirm both agree.
n_samples = config['n_samples']
# Row indices into the flattened predictions for the first truncate_samples
# samples of every good chain.
# NOTE(review): np.concatenate raises ValueError on an empty list, i.e. when
# there is no good chain.
good_chains_pred_indices = np.concatenate(
    [np.arange(truncate_samples) + (n_samples * i) for i in good_chains]
)
rmse_table = rmse_table.sort_values(by='RMSE', ascending=True)
# all entries of the df that are > np.sqrt(mse_linear_model) should get a red
# background
def color_cells(x):
    """Return the CSS for one table cell: red background if ``x`` exceeds the linear-baseline RMSE, else no styling."""
    return 'background-color: red' if x > res_dict["rmse_linear"] else ''
rmse_table = rmse_table.style.map(color_cells)
rmse_table = rmse_table.format('{:.3f}')
rmse_table

### Analyze performance grid over samples and chains

Also save the data for better plotting in R.

In [None]:
# Sample-count grid: 1..10, 50, 100, then every 500 up to truncate_samples.
sample_steps = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
sample_steps += [50, 100]
sample_steps += list(range(500, truncate_samples, 500))
sample_steps += [truncate_samples]
sample_steps = np.unique(sample_steps)
sample_steps = sample_steps[sample_steps <= truncate_samples]
sample_steps = sample_steps.tolist()

# RMSE of the pooled predictions for every combination of "first k samples
# per chain" and "first m good chains". The loop variables are deliberately
# not called n_samples / n_chains so the notebook-level globals of that name
# are not overwritten.
rmse_over_samples_and_chains = []
for _n_keep in sample_steps:
    for _n_used in range(1, len(good_chains) + 1):
        _pooled = preds_chain_dim[good_chains[:_n_used], :_n_keep, ...].reshape(
            -1, *preds_chain_dim.shape[2:]
        )
        rmse_over_samples_and_chains.append(np.sqrt(mse(_pooled, Y_val)[0]))

# Arrange as (sample step, number of chains) and flip the rows so the largest
# sample count is drawn at the top of the heatmap.
rmse_over_samples_and_chains = np.array(rmse_over_samples_and_chains).reshape(
    len(sample_steps), len(good_chains)
)
rmse_over_samples_and_chains = rmse_over_samples_and_chains[::-1, :]
fig = plt.figure(figsize=(10, 6))
sns.heatmap(
    rmse_over_samples_and_chains,
    xticklabels=list(range(1, len(good_chains) + 1)),
    yticklabels=sample_steps[::-1],
    cmap='YlGn_r',
)
# Grid position (row, column) of the lowest RMSE.
min_idx = np.unravel_index(
    np.nanargmin(rmse_over_samples_and_chains),
    rmse_over_samples_and_chains.shape,
)
print(min_idx)
# Mark the best cell with a white cross (+0.5 centers the marker in the cell).
plt.scatter(min_idx[1]+0.5, min_idx[0]+0.5, marker='x', color='white')
plt.xlabel('Number of Chains')
plt.ylabel('Number of Samples (Non-Linear!)')
plt.title(
    (
        'RMSE over Samples and'
        ' Chains (Lower is better)'
    )
)
plt.close(fig)
fig

In [None]:
# save the heatmap
# Rows are the sample steps (largest first, matching the flipped grid).
# NOTE(review): the columns are labelled with the chain ids in good_chains,
# while column j holds the RMSE of the first j+1 good chains pooled together
# (the heatmap labels them 1..len(good_chains)) — confirm which labelling the
# downstream plotting expects.
rmse_over_samples_and_chains_df = pd.DataFrame(
    rmse_over_samples_and_chains,
    index=sample_steps[::-1],
    columns=good_chains,
)
rmse_over_samples_and_chains_df.to_csv(
    f'../paper_bde/practical_sbi/chains_samples_grid/{DATASET}_rmse_over_samples_and_chains.csv'
)
rmse_over_samples_and_chains_df

In [None]:
# Pointwise log predictive values from the model, reshaped to carry a
# leading chain dimension, then exponentiated back to the density scale so
# they can be averaged over samples/chains before taking the log again.
lppd_pointwise = model.get_lppd(X_val, Y_val, posterior_samples_raw, rolling=False)
lppd_pointwise_chain_dim = add_chain_dimension({'lppd': lppd_pointwise}, n_chains=exp_info["n_chains"])['lppd']
ppd_pointwise_chain_dim = jnp.exp(lppd_pointwise_chain_dim)

In [None]:
# Sample-count grid: 1..10, 50, 100, then every 500 up to truncate_samples.
sample_steps = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
sample_steps += [50, 100]
sample_steps += list(range(500, truncate_samples, 500))
sample_steps += [truncate_samples]
sample_steps = np.unique(sample_steps)
sample_steps = sample_steps[sample_steps <= truncate_samples]
sample_steps = sample_steps.tolist()

# LPPD for every combination of "first k samples per chain" and "first m
# good chains": average the predictive density over chains and samples, take
# the log per data point, then average the finite values. The loop variables
# are deliberately not called n_samples / n_chains so the notebook-level
# globals of that name are not overwritten.
lppd_over_samples_and_chains = []
for _n_keep in sample_steps:
    for _n_used in range(1, len(good_chains) + 1):
        inner = jnp.log(
            jnp.mean(
                ppd_pointwise_chain_dim[:, :_n_keep, ...][good_chains[:_n_used], ...],
                axis=(0, 1),
            )
        )
        # Drop data points whose averaged density is 0 (log = -inf) or NaN.
        inner = inner[jnp.isfinite(inner)]
        lppd_over_samples_and_chains.append(
            jnp.nanmean(inner)
        )

# Arrange as (sample step, number of chains) and flip the rows so the largest
# sample count is drawn at the top of the heatmap.
lppd_over_samples_and_chains = np.array(lppd_over_samples_and_chains).reshape(
    len(sample_steps), len(good_chains)
)
lppd_over_samples_and_chains = lppd_over_samples_and_chains[::-1, :]
fig = plt.figure(figsize=(10, 6))

sns.heatmap(
    lppd_over_samples_and_chains,
    xticklabels=list(range(1, len(good_chains) + 1)),
    yticklabels=sample_steps[::-1],
    cmap='YlGn',
)
# Grid position (row, column) of the highest LPPD. The name min_idx is kept
# for parity with the RMSE grid cell although it marks the maximum here.
min_idx = np.unravel_index(
    np.nanargmax(lppd_over_samples_and_chains),
    lppd_over_samples_and_chains.shape,
)
print(min_idx)
# Mark the best cell with a white cross (+0.5 centers the marker in the cell).
plt.scatter(min_idx[1]+0.5, min_idx[0]+0.5, marker='x', color='white')
plt.xlabel('Number of Chains')
plt.ylabel('Number of Samples (Non-Linear!)')
plt.title(
    (
        'LPPD over Samples and'
        ' Chains (Higher is better)'
    )
)
plt.close(fig)
fig

In [None]:
# Save the LPPD grid; rows are the sample steps (largest first).
# NOTE(review): columns are labelled with the chain ids in good_chains, while
# column j holds the LPPD of the first j+1 good chains pooled together —
# confirm which labelling the downstream plotting expects.
lppd_over_samples_and_chains_df = pd.DataFrame(
    lppd_over_samples_and_chains,
    index=sample_steps[::-1],
    columns=good_chains,
)
lppd_over_samples_and_chains_df.to_csv(
    f'../paper_bde/practical_sbi/chains_samples_grid/{DATASET}_lppd_over_samples_and_chains.csv'
)


### Calibration assessment

Again, save the data for better plotting in R.

In [None]:
# For several nominal levels q, compute the HPDI of the pooled good-chain
# predictions per validation point and print the fraction of targets that
# fall inside the interval (empirical coverage).
for q in [0.25, 0.5, 0.75, 0.9, 0.98]:
    hpdi_preds = hpdi(preds[good_chains_pred_indices], q)
    # Row 0 is used as the lower bound and row 1 as the upper bound.
    acc_hpdi = jnp.mean(
        (hpdi_preds[0, :] <= Y_val.squeeze())
        & (hpdi_preds[1, :] >= Y_val.squeeze())
    )
    print(f'Accuracy of {int(q*100)}% HPDI: {acc_hpdi:.2f}')

In [None]:
# Calibration curves: empirical HPDI coverage vs. nominal level for a growing
# number of samples per chain (pooled over all good chains). The results are
# also collected in coverage_df for export.
fig = plt.figure(figsize=(10, 6))
colors = sns.color_palette('YlGn', 5)
quantiles = np.linspace(0.1, 0.9, 9)
# Add extreme levels on both ends; the levels must stay sorted so that the
# plotted line runs monotonically along the x-axis.
quantiles = np.concatenate([np.array([0.01, 0.025, 0.05]), quantiles, np.array([0.95, 0.99])])
coverage_df = pd.DataFrame()
# reverse the colors
colors = colors[::-1]
# One color is popped per curve, so the palette size must cover this list.
for trunc in [1, 10, 100, 1000, config["n_samples"]]:
    nominal_coverage_vals = []
    for q in quantiles:
        # Pool the first `trunc` samples of every good chain.
        hpdi_preds = hpdi(preds_chain_dim[good_chains, :, :][:, :trunc, :].reshape(-1, preds.shape[1]), q)
        nominal_coverage_vals.append(jnp.mean(
            (hpdi_preds[0, :] <= Y_val.squeeze())
            & (hpdi_preds[1, :] >= Y_val.squeeze())
        ))
    coverage_df_temp = pd.DataFrame({
        "quantiles": quantiles,
        "empirical_coverage": nominal_coverage_vals,
        "truncation": trunc,
        "type": "samples"
    })
    coverage_df = pd.concat([coverage_df, coverage_df_temp])
    plt.plot(
        quantiles,
        nominal_coverage_vals,
        label=f'{trunc} samples',
        color=colors.pop(),
        marker='o',
        markerfacecolor='black',
    )
plt.legend()
plt.xlabel('Nominal Coverage Level')
plt.ylabel('Observed Coverage Level')
plt.title('Nominal Coverage of PP Credibility Intervals')
plt.ylim([0, 1])
plt.xlim([0, 1])
plt.gca().set_aspect('equal', adjustable='box')
# Diagonal = perfect calibration; the shaded triangle below it marks
# intervals that cover less often than their nominal level.
plt.plot([0, 1], [0, 1], color='black', linestyle='--')
plt.fill_between([0, 1], [0, 0], [0, 1], color='red', alpha=0.2)
plt.annotate(
    'Overconfidence',
    xy=(0.75, 0.25),
    xytext=(0.75, 0.25),
    ha='center',
    va='center',
    color='red',
)
plt.close(fig)
fig

In [None]:
# Same calibration curves as above, zoomed in on the small nominal levels
# (0.01 to 0.1). Nothing is appended to coverage_df in this cell.
fig = plt.figure(figsize=(10, 6))
colors = sns.color_palette('YlGn', 5)
quantiles =  np.array([0.01, 0.025, 0.05, 0.1])
# reverse the colors
colors = colors[::-1]
# One color is popped per curve, so the palette size must cover this list.
for trunc in [1, 10, 100, 1000, config["n_samples"]]: 
    nominal_coverage_vals = []
    for q in quantiles:
        # Pool the first `trunc` samples of every good chain.
        hpdi_preds = hpdi(preds_chain_dim[good_chains, :, :][:, :trunc, :].reshape(-1, preds.shape[1]), q)
        nominal_coverage_vals.append(jnp.mean(
            (hpdi_preds[0, :] <= Y_val.squeeze())
            & (hpdi_preds[1, :] >= Y_val.squeeze())
        ))
    plt.plot(
        quantiles, 
        nominal_coverage_vals, 
        label=f'{trunc} samples',
        color=colors.pop(),
        marker='o',
        markerfacecolor='black',
    )
plt.legend()
plt.xlabel('Nominal Coverage Level')
plt.ylabel('Observed Coverage Level')
plt.title('Nominal Coverage of PP Credibility Intervals')
plt.ylim([0, 0.11])
plt.xlim([0, 0.11])
plt.gca().set_aspect('equal', adjustable='box')
# Diagonal = perfect calibration; shaded area below it = under-coverage.
plt.plot([0, 0.11], [0, 0.11], color='black', linestyle='--')
plt.fill_between([0, 0.11], [0, 0], [0, 0.11], color='red', alpha=0.2)
plt.xticks(quantiles)
plt.annotate(
    'Overconfidence',
    xy=(0.08, 0.03),
    xytext=(0.08, 0.03),
    ha='center',
    va = 'center',
    color='red',
)
plt.close(fig)
fig

In [None]:
# Calibration curves for a growing number of good chains at a fixed number of
# samples per chain (trunc). Results are appended to coverage_df, where the
# "truncation" column holds the chain count for these rows.
fig = plt.figure(figsize=(10, 6))
colors = sns.color_palette('YlGn', 5)
quantiles =  np.linspace(0.1, 0.9, 9)
quantiles = np.concatenate([np.array([0.01, 0.025, 0.05]), quantiles, np.array([0.95, 0.99])])
colors = colors[::-1]
trunc = 100
# NOTE(review): when fewer good chains exist than a requested count, the
# slice [:chains] silently uses all available ones, so several curves (and
# CSV rows) may be identical while labelled with different chain counts.
for chains in [1, 2, 4, 8, 10]: 
    nominal_coverage_vals = []
    for q in quantiles:
        # Pool the first `trunc` samples of the first `chains` good chains.
        hpdi_preds = hpdi(preds_chain_dim[good_chains, ...][:chains, :, :][:, :trunc, :].reshape(-1, preds.shape[1]), q)
        nominal_coverage_vals.append(jnp.mean(
            (hpdi_preds[0, :] <= Y_val.squeeze())
            & (hpdi_preds[1, :] >= Y_val.squeeze())
        ))
    coverage_df_temp = pd.DataFrame({
        "quantiles": quantiles,
        "empirical_coverage": nominal_coverage_vals,
        "truncation": chains,
        "type": f"chains_{trunc}samples"
    })
    coverage_df = pd.concat([coverage_df, coverage_df_temp])
    plt.plot(
        quantiles, 
        nominal_coverage_vals, 
        label=f'{chains} Chains',
        color=colors.pop(),
        marker='o',
        markerfacecolor='black',
    )
plt.legend()
plt.xlabel('Nominal Coverage Level')
plt.ylabel('Observed Coverage Level')
plt.title(f'Nominal Coverage of PP Credibility Intervals ({trunc} Samples)')
plt.ylim([0, 1])
plt.xlim([0, 1])
plt.gca().set_aspect('equal', adjustable='box')
# Diagonal = perfect calibration; shaded area below it = under-coverage.
plt.plot([0, 1], [0, 1], color='black', linestyle='--')
plt.fill_between([0, 1], [0, 0], [0, 1], color='red', alpha=0.2)
plt.annotate(
    'Overconfidence',
    xy=(0.75, 0.25),
    xytext=(0.75, 0.25),
    ha='center',
    va = 'center',
    color='red',
)
plt.close(fig)
fig

In [None]:
# Export the coverage values collected by the calibration cells above.
coverage_df.to_csv(
    f'../paper_bde/practical_sbi/calibration/{DATASET}_calibration.csv'
)

In [None]:
# Restrict every parameter to the good chains and the first truncate_samples
# samples, then visualize the per-chain posterior predictive means.
trunc_posterior_samples = {
    k: v[good_chains, :truncate_samples, ...] for k, v in posterior_samples.items()
}
interchain_means_normal = pp_interchain_means(
    trunc_posterior_samples, model, X_val
)
fig, ax = visualize_pp_chain_means(interchain_means_normal, 100, show=False)
# Label the ticks with the original chain ids instead of positions.
ax.set_xticklabels([str(i) for i in good_chains])

# Convergence Diagnostics

### Prepare

In [None]:
# Posterior samples of the good chains only (all samples per chain).
good_chains_posterior_samples = {
    k: v[good_chains, ...] for k, v in posterior_samples.items()
}

In [None]:
# Convergence diagnostics per parameter group. The cells below read index 0
# of each result as a dict of R-hat statistics and index 1 as ESS values.
all_params = {}
for param in good_chains_posterior_samples.keys():
    all_params[param] = calculate_diagnostics(
        good_chains_posterior_samples,
        param,
        truncate_samples,
    )

In [None]:
# Long-format table (layer, parameter, metric, mean, sd) filled by the cells below.
parameter_conv_diag_df = pd.DataFrame()

### Parameter Space

**ESS**


In [None]:
# Mean and standard deviation of the ESS per layer, separately for weights
# and biases; the values are appended to parameter_conv_diag_df and plotted.
parameter_list = list(all_params.keys())
# Parameters are split by name: anything containing 'W' / 'b'.
weight_parameters = [p for p in parameter_list if 'W' in p]
bias_parameters = [p for p in parameter_list if 'b' in p]

# calculate the average ESS for parameter layer
ess_per_layer = {}
bias_ess_per_layer = {}
sd_ess_per_layer = {}
sd_bias_ess_per_layer = {}
# Layers are addressed as W1..WL / b1..bL, L being the number of weight groups.
for layer in range(1, len(weight_parameters)+1):
    ess_per_layer[layer] = float(np.mean(all_params[f"W{layer}"][1]))
    bias_ess_per_layer[layer] = float(np.mean(all_params[f"b{layer}"][1]))
    sd_ess_per_layer[layer] = float(np.std(all_params[f"W{layer}"][1]))
    sd_bias_ess_per_layer[layer] = float(np.std(all_params[f"b{layer}"][1]))
    parameter_conv_diag_df = pd.concat(
        [
            parameter_conv_diag_df, 
            pd.DataFrame(
                {
                    "layer": [layer],
                    "parameter": ["weight"],
                    "metric": ["ess"],
                    "mean": [ess_per_layer[layer]],
                    "sd": [sd_ess_per_layer[layer]],
                }
            ),
            pd.DataFrame(
                {
                    "layer": [layer],
                    "parameter": ["bias"],
                    "metric": ["ess"],
                    "mean": [bias_ess_per_layer[layer]],
                    "sd": [sd_bias_ess_per_layer[layer]],
                }
            ),
        ]
    )
# Lines show the per-layer mean; error bars show +/- one standard deviation.
fig, ax = plt.subplots(figsize=(10, 6))
sns.lineplot(
    x=list(ess_per_layer.keys()),
    y=list(ess_per_layer.values()),
    ax=ax,
    label='Weight',
    color='blue',
)
sns.lineplot(
    x=list(bias_ess_per_layer.keys()),
    y=list(bias_ess_per_layer.values()),
    ax=ax,
    label='Bias',
    color='orange',
)
ax.errorbar(
    list(ess_per_layer.keys()),
    list(ess_per_layer.values()),
    yerr=list(sd_ess_per_layer.values()),
    color='blue',
    fmt='o',
    capsize=5,
)
ax.errorbar(
    list(bias_ess_per_layer.keys()),
    list(bias_ess_per_layer.values()),
    yerr=list(sd_bias_ess_per_layer.values()),
    color='orange',
    fmt='o',
    capsize=5,
)
ax.set_ylim(bottom=0)
ax.set_xlabel('Parameter Layer')
ax.set_ylabel('Average ESS')
# Layer numbers are integers; avoid fractional tick labels.
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
plt.close(fig)
fig

$\hat R$


In [None]:
# Pool every R-hat statistic over all parameter groups: for each statistic
# key, the flattened per-parameter arrays are concatenated into one vector.
all_params_rhat_stats = {}
for param in all_params.keys():
    paramdict = all_params[param][0]
    for key in paramdict.keys():
        all_params_rhat_stats[key] = np.concatenate(
            [all_params_rhat_stats[key] , paramdict[key].reshape(-1)]
        ) if key in all_params_rhat_stats.keys() else paramdict[key].reshape(-1)

# Flat vector of the 'rhat' entry of all parameters, exported on its own.
all_rhats = []
for param in all_params.keys():
    all_rhats.append(all_params[param][0]['rhat'].flatten())
all_rhats = np.concatenate(all_rhats).flatten()
# save to csv
# The file name carries the number of layers recorded in parameter_conv_diag_df.
pd.DataFrame(all_rhats).to_csv(
    f'../paper_bde/practical_sbi/convergence/{DATASET}_{parameter_conv_diag_df["layer"].max()}layers_rhat_params.csv'
)
visualize_rhat(all_params_rhat_stats)  

In [None]:
# R-hat statistics per layer, separately for weights and biases. Parameters
# are looked up by their exact key (W<layer> / b<layer>), the same convention
# the ESS cell uses; a substring test on the layer number would also match
# e.g. 'W10' for layer 1 in networks with ten or more layers.
layerwise_params_rhat_stats_weights = {}
layerwise_params_rhat_stats_biases = {}
for layer in range(1, len(weight_parameters)+1):
    # Flatten every statistic so each layer maps to 1-D vectors.
    layerwise_params_rhat_stats_weights[layer] = {
        key: stat.reshape(-1)
        for key, stat in all_params[f"W{layer}"][0].items()
    }
    print(f"Layer {layer} - Weights")
    display(visualize_rhat(layerwise_params_rhat_stats_weights[layer]))
    layerwise_params_rhat_stats_biases[layer] = {
        key: stat.reshape(-1)
        for key, stat in all_params[f"b{layer}"][0].items()
    }
    print(f"Layer {layer} - Biases")
    display(visualize_rhat(layerwise_params_rhat_stats_biases[layer]))
    

Classical R-hat and chainwise R-hat

In [None]:
# Per-layer mean and standard deviation of the classical R-hat ("rhat") and
# the split-chain R-hat ("split_chain_rhat"), for weights (_w) and biases (_b).
# Each list is ordered like the keys of the corresponding layerwise dict.
layerwise_rhat_w = [
    jnp.mean(layerwise_params_rhat_stats_weights[layer]["rhat"]).item() for layer in layerwise_params_rhat_stats_weights.keys()
]
layerwise_rhat_sd_w = [
    jnp.std(layerwise_params_rhat_stats_weights[layer]["rhat"]).item() for layer in layerwise_params_rhat_stats_weights.keys()
]
layerwise_split_rhat_w = [
    jnp.mean(layerwise_params_rhat_stats_weights[layer]["split_chain_rhat"]).item() for layer in layerwise_params_rhat_stats_weights.keys()
]
layerwise_split_rhat_sd_w = [
    jnp.std(layerwise_params_rhat_stats_weights[layer]["split_chain_rhat"]).item() for layer in layerwise_params_rhat_stats_weights.keys()
]
layerwise_rhat_b = [
    jnp.mean(layerwise_params_rhat_stats_biases[layer]["rhat"]).item() for layer in layerwise_params_rhat_stats_biases.keys()
]
layerwise_rhat_sd_b = [
    jnp.std(layerwise_params_rhat_stats_biases[layer]["rhat"]).item() for layer in layerwise_params_rhat_stats_biases.keys()
]
layerwise_split_rhat_b = [
    jnp.mean(layerwise_params_rhat_stats_biases[layer]["split_chain_rhat"]).item() for layer in layerwise_params_rhat_stats_biases.keys()
]
layerwise_split_rhat_sd_b = [
    jnp.std(layerwise_params_rhat_stats_biases[layer]["split_chain_rhat"]).item() for layer in layerwise_params_rhat_stats_biases.keys()
]
# Append the summaries to the long-format table. The lists are indexed with
# layer-1, which relies on the layer keys being 1, 2, ... in insertion order.
for layer in layerwise_params_rhat_stats_weights.keys():
    parameter_conv_diag_df = pd.concat(
        [
            parameter_conv_diag_df,
            pd.DataFrame(
                {
                    "layer": [layer],
                    "parameter": ["weight"],
                    "metric": ["rhat"],
                    "mean": [layerwise_rhat_w[layer-1]],
                    "sd": [layerwise_rhat_sd_w[layer-1]],
                }
            ),
            pd.DataFrame(
                {
                    "layer": [layer],
                    "parameter": ["bias"],
                    "metric": ["rhat"],
                    "mean": [layerwise_rhat_b[layer-1]],
                    "sd": [layerwise_rhat_sd_b[layer-1]],
                }
            ),
            pd.DataFrame(
                {
                    "layer": [layer],
                    "parameter": ["weight"],
                    "metric": ["split_chain_rhat"],
                    "mean": [layerwise_split_rhat_w[layer-1]],
                    "sd": [layerwise_split_rhat_sd_w[layer-1]],
                }
            ),
            pd.DataFrame(
                {
                    "layer": [layer],
                    "parameter": ["bias"],
                    "metric": ["split_chain_rhat"],
                    "mean": [layerwise_split_rhat_b[layer-1]],
                    "sd": [layerwise_split_rhat_sd_b[layer-1]],
                }
            ),
        ]
    )

# Per-layer mean R-hat (solid lines) and chainwise/split-chain R-hat (dashed
# lines) for weights and biases, with +/- one standard deviation error bars.
fig, ax = plt.subplots(figsize=(10, 6))
_weight_layers = list(layerwise_params_rhat_stats_weights.keys())
_bias_layers = list(layerwise_params_rhat_stats_biases.keys())
# (x, mean, sd, legend label, color, linestyle). Labels are raw strings so
# the LaTeX backslash is not read as an (invalid) string escape sequence.
_rhat_series = [
    (_weight_layers, layerwise_rhat_w, layerwise_rhat_sd_w,
     r'$\widehat{R}$ of Weights', '#06238f', None),
    (_bias_layers, layerwise_rhat_b, layerwise_rhat_sd_b,
     r'$\widehat{R}$ of Biases', '#2e59f2', None),
    (_weight_layers, layerwise_split_rhat_w, layerwise_split_rhat_sd_w,
     r'Chainwise $\widehat{R}$ of Weights', '#035c0c', '--'),
    (_bias_layers, layerwise_split_rhat_b, layerwise_split_rhat_sd_b,
     r'Chainwise $\widehat{R}$ of Biases', '#26bf36', '--'),
]
# Draw all lines first, then all error bars, to keep the original z-order
# and legend order.
for _x, _mean, _sd, _label, _color, _linestyle in _rhat_series:
    _line_style = {} if _linestyle is None else {'linestyle': _linestyle}
    sns.lineplot(x=_x, y=_mean, ax=ax, label=_label, color=_color, **_line_style)
for _x, _mean, _sd, _label, _color, _linestyle in _rhat_series:
    ax.errorbar(_x, _mean, yerr=_sd, color=_color, fmt='o', capsize=5)
ax.set_ylim(bottom=1)
ax.set_xlabel('Hidden Layer')
ax.set_ylabel('')
plt.close(fig)
fig


In [None]:
# investigate the decomposition of the classic Rhat across layers (parameters) for both weights and biases (only use good chains)
# For every parameter entry:
#   between-chain variance = variance (over chains) of the per-chain means
#   within-chain variance  = mean (over chains) of the per-chain variances
# Axis 0 indexes chains and axis 1 indexes samples. Weights and biases use
# the same definition, so both are variances (squared standard deviations).
between_chain_var_w = {}
within_chain_var_w = {}
between_chain_var_b = {}
within_chain_var_b = {}

for param, _draws in good_chains_posterior_samples.items():
    if 'W' in param:
        _between, _within = between_chain_var_w, within_chain_var_w
    elif 'b' in param:
        _between, _within = between_chain_var_b, within_chain_var_b
    else:
        continue
    _between[param] = _draws.mean(axis=1).std(axis=0) ** 2
    _within[param] = (_draws.std(axis=1) ** 2).mean(axis=0)
# Aggregate the per-entry between-/within-chain values to one mean and one
# standard deviation per layer, and append them to parameter_conv_diag_df.
between_chain_var_w_mean = {}
between_chain_var_w_std = {}
within_chain_var_w_mean = {}
within_chain_var_w_std = {}
between_chain_var_b_mean = {}
between_chain_var_b_std = {}
within_chain_var_b_mean = {}
within_chain_var_b_std = {}
# Layers are addressed as W1..WL / b1..bL; .item() yields Python scalars.
for layer in range(1, len(weight_parameters)+1):
    between_chain_var_w_mean[layer] = np.mean(between_chain_var_w[f"W{layer}"]).item()
    between_chain_var_w_std[layer] = np.std(between_chain_var_w[f"W{layer}"]).item()
    within_chain_var_w_mean[layer] = np.mean(within_chain_var_w[f"W{layer}"]).item()
    within_chain_var_w_std[layer] = np.std(within_chain_var_w[f"W{layer}"]).item()
    between_chain_var_b_mean[layer] = np.mean(between_chain_var_b[f"b{layer}"]).item()
    between_chain_var_b_std[layer] = np.std(between_chain_var_b[f"b{layer}"]).item()
    within_chain_var_b_mean[layer] = np.mean(within_chain_var_b[f"b{layer}"]).item()
    within_chain_var_b_std[layer] = np.std(within_chain_var_b[f"b{layer}"]).item()
    parameter_conv_diag_df = pd.concat([
        parameter_conv_diag_df,
        pd.DataFrame(
            {
                "layer": [layer],
                "parameter": ["weight"],
                "metric": ["between_chain_var"],
                "mean": [between_chain_var_w_mean[layer]],
                "sd": [between_chain_var_w_std[layer]],
            }
        ),
        pd.DataFrame(
            {
                "layer": [layer],
                "parameter": ["bias"],
                "metric": ["between_chain_var"],
                "mean": [between_chain_var_b_mean[layer]],
                "sd": [between_chain_var_b_std[layer]],
            }
        ),
        pd.DataFrame(
            {
                "layer": [layer],
                "parameter": ["weight"],
                "metric": ["within_chain_var"],
                "mean": [within_chain_var_w_mean[layer]],
                "sd": [within_chain_var_w_std[layer]],
            }
        ),
        pd.DataFrame(
            {
                "layer": [layer],
                "parameter": ["bias"],
                "metric": ["within_chain_var"],
                "mean": [within_chain_var_b_mean[layer]],
                "sd": [within_chain_var_b_std[layer]],
            }
        ),
    ])

# Figure 1: per-layer average between-chain value for weights and biases;
# error bars show +/- one standard deviation across the layer's entries.
fig, ax = plt.subplots(figsize=(10, 6))
sns.lineplot(
    x=list(between_chain_var_w_mean.keys()),
    y=list(between_chain_var_w_mean.values()),
    ax=ax,
    label='Weight',
    color='#06238f',
)
sns.lineplot(
    x=list(between_chain_var_b_mean.keys()),
    y=list(between_chain_var_b_mean.values()),
    ax=ax,
    label='Bias',
    color='#2e59f2',
)
ax.errorbar(
    list(between_chain_var_w_mean.keys()),
    list(between_chain_var_w_mean.values()),
    yerr=list(between_chain_var_w_std.values()),
    color='#06238f',
    fmt='o',
    capsize=5,
)
ax.errorbar(
    list(between_chain_var_b_mean.keys()),
    list(between_chain_var_b_mean.values()),
    yerr=list(between_chain_var_b_std.values()),
    color='#2e59f2',
    fmt='o',
    capsize=5,
)
ax.set_ylim(bottom=0)
ax.set_xlabel('Layer')
ax.set_ylabel('Avg. Between Chain Variance')
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
# Close the pyplot handle and show the figure explicitly, since this cell
# renders two figures.
plt.close(fig)
display(fig)

# Figure 2: the same layout for the within-chain values.
fig, ax = plt.subplots(figsize=(10, 6))
sns.lineplot(
    x=list(within_chain_var_w_mean.keys()),
    y=list(within_chain_var_w_mean.values()),
    ax=ax,
    label='Weight',
    color='#06238f',
)
sns.lineplot(
    x=list(within_chain_var_b_mean.keys()),
    y=list(within_chain_var_b_mean.values()),
    ax=ax,
    label='Bias',
    color='#2e59f2',
)
ax.errorbar(
    list(within_chain_var_w_mean.keys()),
    list(within_chain_var_w_mean.values()),
    yerr=list(within_chain_var_w_std.values()),
    color='#06238f',
    fmt='o',
    capsize=5,
)
ax.errorbar(
    list(within_chain_var_b_mean.keys()),
    list(within_chain_var_b_mean.values()),
    yerr=list(within_chain_var_b_std.values()),
    color='#2e59f2',
    fmt='o',
    capsize=5,
)
ax.set_ylim(bottom=0)
ax.set_xlabel('Layer')
ax.set_ylabel('Avg. Within Chain Variance')
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
plt.close(fig)
display(fig)

In [None]:
# Investigate the absolute value of the sampled parameters and its standard
# deviation across layers, for both weights and biases (good chains only).
abs_weight_samples = {}
abs_bias_samples = {}
for param in good_chains_posterior_samples.keys():
    if 'W' in param:
        abs_weight_samples[param] = np.abs(good_chains_posterior_samples[param])
    elif 'b' in param:
        abs_bias_samples[param] = np.abs(good_chains_posterior_samples[param])
abs_weight_samples_mean = {}
abs_weight_samples_std = {}
abs_bias_samples_mean = {}
abs_bias_samples_std = {}
# Mean/std are taken over all chains, samples and entries of a layer.
# NOTE(review): unlike the other diagnostics cells, the results are not
# converted with float()/.item() before being stored — confirm this is fine
# for the exported CSV.
for layer in range(1, len(weight_parameters)+1):
    abs_weight_samples_mean[layer] = np.mean(abs_weight_samples[f"W{layer}"])
    abs_weight_samples_std[layer] = np.std(abs_weight_samples[f"W{layer}"])
    abs_bias_samples_mean[layer] = np.mean(abs_bias_samples[f"b{layer}"])
    abs_bias_samples_std[layer] = np.std(abs_bias_samples[f"b{layer}"])
    parameter_conv_diag_df = pd.concat([
        parameter_conv_diag_df,
        pd.DataFrame(
            {
                "layer": [layer],
                "parameter": ["weight"],
                "metric": ["abs_samples_mean"],
                "mean": [abs_weight_samples_mean[layer]],
                "sd": [abs_weight_samples_std[layer]],
            }
        ),
        pd.DataFrame(
            {
                "layer": [layer],
                "parameter": ["bias"],
                "metric": ["abs_samples_mean"],
                "mean": [abs_bias_samples_mean[layer]],
                "sd": [abs_bias_samples_std[layer]],
            }
        ),
    ])
# Lines show the per-layer mean; error bars show +/- one standard deviation.
fig, ax = plt.subplots(figsize=(10, 6))
sns.lineplot(
    x=list(abs_weight_samples_mean.keys()),
    y=list(abs_weight_samples_mean.values()),
    ax=ax,
    label='Weight',
    color='blue',
)
sns.lineplot(
    x=list(abs_bias_samples_mean.keys()),
    y=list(abs_bias_samples_mean.values()),
    ax=ax,
    label='Bias',
    color='orange',
)
ax.errorbar(
    list(abs_weight_samples_mean.keys()),
    list(abs_weight_samples_mean.values()),
    yerr=list(abs_weight_samples_std.values()),
    color='blue',
    fmt='o',
    capsize=5,
)
ax.errorbar(
    list(abs_bias_samples_mean.keys()),
    list(abs_bias_samples_mean.values()),
    yerr=list(abs_bias_samples_std.values()),
    color='orange',
    fmt='o',
    capsize=5,
)
ax.set_ylim(bottom=0)
ax.set_xlabel('Layer')
ax.set_ylabel('Average Absolute Value of Posterior (Parameter) Samples')
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
plt.close(fig)
fig

As always, save the data to file for better plotting in R and further analysis.

In [None]:
# Tag every row with the dataset and the network depth (largest layer index
# recorded), then export the parameter-space diagnostics.
parameter_conv_diag_df["data"] = DATASET
parameter_conv_diag_df["n_layers"] = parameter_conv_diag_df["layer"].max()
parameter_conv_diag_df.to_csv(
    f'../paper_bde/practical_sbi/convergence/{DATASET}_{parameter_conv_diag_df["layer"].max()}layers_parameter_conv.csv'
)

### Function Space

In [None]:
# Split-chain R-hat of the posterior predictions (function space), computed
# on the good chains and the first truncate_samples samples, with 4 splits.
pp_split_chain_rhat = split_chain_r_hat(
    preds_chain_dim[good_chains, :truncate_samples, :],
    n_splits=4,
)
fig_pp_rhat = visualize_pp_rhat(pp_split_chain_rhat)
fig_pp_rhat

In [None]:
# save all function space rhats as above
# Only the flattened "rhat" entry is exported.
pd.DataFrame(pp_split_chain_rhat["rhat"].flatten()).to_csv(
    f'../paper_bde/practical_sbi/convergence/{DATASET}_{parameter_conv_diag_df["layer"].max()}layers_rhat_pp.csv'
)

In [None]:
# Recompute the pointwise log predictive values from the raw samples (same
# call as in the predictive-performance section) and inspect the shape.
lppd_pointwise = model.get_lppd(X_val, Y_val, posterior_samples_raw, rolling=False)
print(lppd_pointwise.shape)

In [None]:
# Add the leading chain dimension; after the subsetting below, position k
# along axis 0 corresponds to good_chains[k], not to the original chain id.
lppd_pointwise_chain_dim = add_chain_dimension({'lppd': lppd_pointwise}, n_chains=exp_info["n_chains"])['lppd']
# only use the good chains
lppd_pointwise_chain_dim = lppd_pointwise_chain_dim[good_chains, ...]
ppd_pointwise_chain_dim = jnp.exp(lppd_pointwise_chain_dim)

In [None]:
# Per chain and sample: log of the predictive density averaged over axis 2
# (presumably the validation points — verify). A trailing singleton axis is
# kept so the array can be passed to gelman_split_r_hat further below.
chainwise_lpl = jnp.log(jnp.expand_dims(ppd_pointwise_chain_dim.mean(axis=2), axis=2))
# Chain-major flat vector of the same values.
flat_lpl = chainwise_lpl[..., 0].reshape(-1)

In [None]:
# Mean and standard deviation of the LPL over the samples of each good chain.
chainwise_mean_lpl = chainwise_lpl[..., 0].mean(axis=1)
chainwise_sd_lpl = chainwise_lpl[..., 0].std(axis=1)
# visualize lineplot with error bars
# The x positions are the original chain ids of the good chains.
fig, ax = plt.subplots(figsize=(10, 6))
sns.lineplot(
    x=good_chains,
    y=chainwise_mean_lpl,
    ax=ax,
    color='blue',
)
ax.errorbar(
    good_chains,
    chainwise_mean_lpl,
    yerr=chainwise_sd_lpl,
    color='blue',
    fmt='o',
    capsize=5,
)
ax.set_xlabel('Chain')
ax.set_xticks(good_chains)
ax.set_ylabel('Chainwise LPL')
plt.close(fig)
fig

In [None]:
# visualize it as a colored traceplot
sns.set_theme(style='whitegrid')
fig, ax = plt.subplots(figsize=(10, 6))
# Chains are named by their position among the good chains, not by their
# original chain id.
chain_names = [f'chain_{i}' for i in range(len(good_chains))]
# One chain label per sample.
# NOTE(review): this assumes each chain holds exactly truncate_samples
# samples; flat_lpl itself is not truncated.
colors = np.repeat(chain_names, truncate_samples, axis=0)
# Start the function-space export table with the LPL trace; the rank-
# normalized split R-hat of the LPL is repeated on every row.
function_space_df = pd.DataFrame({
    "sample": np.tile(np.arange(truncate_samples), len(good_chains))+1,
    "base_val": flat_lpl,
    "chain": colors,
    "type": "lpl",
    "rhat": gelman_split_r_hat(chainwise_lpl, n_splits=2, rank_normalize=True).item(),
})
# NOTE(review): the plotted values are exp(LPL), while the axis label and
# title say 'LPL' — confirm which scale is intended.
sns.lineplot(
    x=np.arange(len(flat_lpl)),
    y=jnp.exp(flat_lpl),
    ax=ax,
    hue = colors
)
# ylim 0 1
# ax.set_ylim([0, 1])
ax.set_xlabel('Sample')
ax.set_ylabel('LPL')
ax.set_title('LPL')

In [None]:
# Rank-normalized split R-hat (2 splits) of the chainwise LPL, shown as the cell output.
gelman_split_r_hat(chainwise_lpl, n_splits=2, rank_normalize=True).item()

Now in a cumulative (expanding-window) fashion: the LPPD of each chain is recomputed from its first `k` samples as `k` grows.

In [None]:
# Cumulative LPPD per good chain: for a growing prefix of samples, average
# the predictive density over the prefix, take the log per data point and
# average the finite values.
chainwise_rolling_lppd = []
# ppd_pointwise_chain_dim has already been restricted to the good chains, so
# its first axis must be addressed by position. Indexing it with the original
# chain ids would select the wrong chains (JAX clamps out-of-range indices
# instead of raising).
for chain in range(ppd_pointwise_chain_dim.shape[0]):
    chain_cum_lppd = []
    # only every 20th sample, to keep the computation cheap
    for samp in range(0, ppd_pointwise_chain_dim.shape[1], 20):
        inner = jnp.log(
            jnp.mean(
                ppd_pointwise_chain_dim[chain, :samp+1, ...],
                axis=0,
            )
        )
        # Drop data points whose averaged density is 0 (log = -inf) or NaN.
        inner = inner[jnp.isfinite(inner)]
        chain_cum_lppd.append(
            jnp.mean(inner)
        )
    chainwise_rolling_lppd.append(
        chain_cum_lppd
    )
chainwise_rolling_lppd = np.array(chainwise_rolling_lppd)
chainwise_rolling_lppd.shape

In [None]:
# Flatten the (chain, step) LPPD grid, append it to the function-space
# export table and draw one trace per chain.
flat_rolling_lppd = flatten_chain_dimension({"clppd": chainwise_rolling_lppd})["clppd"]
sns.set_theme(style='whitegrid')
fig, ax = plt.subplots(figsize=(10, 6))
# Chains are named by their position among the good chains.
chain_names = [f'chain_{i}' for i in range(len(good_chains))]
colors = np.repeat(chain_names, chainwise_rolling_lppd.shape[1], axis=0)
# "sample" is the 1-based index of the evaluated step here, not the number of
# samples used; no R-hat is computed for this trace.
function_space_df = pd.concat([
    function_space_df,
    pd.DataFrame({
        "sample": np.tile(np.arange(chainwise_rolling_lppd.shape[1]), len(good_chains))+1,
        "base_val": flat_rolling_lppd,
        "chain": colors,
        "type": "lppd",
        "rhat": np.nan,
    })
])
# The x positions are 20 * step index (20, 40, ...), repeated per chain.
sns.lineplot(
    x=20*np.array(list(np.arange(1, chainwise_rolling_lppd.shape[1]+1)) * len(good_chains)),
    y=flat_rolling_lppd,
    ax=ax,
    hue = colors
)
ax.set_xlabel('Sample')
ax.set_ylabel('LPPD')

Now with the RMSE instead of the LPPD.

In [None]:
def pointwise_rmse(preds, Y):
    """Return the root mean squared error between ``preds`` and ``Y``.

    Both inputs are squeezed first, so singleton axes (e.g. a trailing
    dimension of size 1) are ignored; the mean is taken over all entries
    of the squared difference.
    """
    residuals = preds.squeeze() - Y.squeeze()
    return jnp.sqrt(jnp.mean(jnp.square(residuals)))

# RMSE of every posterior sample's predictions against Y_val, computed along
# axis 2 of preds_chain_dim, then restricted to the chains in good_chains.
# (Despite the name, mse_per_chain holds RMSE values.)
mse_per_chain = jnp.apply_along_axis(pointwise_rmse, 2, preds_chain_dim, Y_val)
mse_per_chain = mse_per_chain[good_chains, ...]
mse_flat = mse_per_chain.reshape(-1)
sns.set_theme(style='whitegrid')
fig, ax = plt.subplots(figsize=(10, 6))
chain_names = [f'chain_{i}' for i in range(len(good_chains))]
# NOTE(review): assumes every chain holds exactly truncate_samples samples,
# otherwise the column lengths below do not match mse_flat -- confirm.
colors = np.repeat(chain_names, truncate_samples, axis=0)
function_space_df = pd.concat([
    function_space_df,
    pd.DataFrame({
        "sample": np.tile(np.arange(truncate_samples), len(good_chains))+1,
        "base_val": mse_flat,
        "chain": colors,
        "type": "rmse",
        # mse_per_chain is already restricted to good_chains above; indexing
        # it with good_chains a second time would select the wrong rows (or
        # go out of range) whenever good_chains is a proper subset of chains.
        "rhat": gelman_split_r_hat(
            jnp.expand_dims(mse_per_chain, 2),
            n_splits=2,
            rank_normalize=True,
        ).item(),
    })
])
sns.lineplot(
    x=np.arange(len(mse_flat)),
    y=mse_flat,
    ax=ax,
    hue=colors,
)
ax.set_xlabel('Sample')
ax.set_ylabel('RMSE')
ax.set_title('RMSE of each Posterior Sample induced Model')


In [None]:
# Rank-normalized split R-hat of the per-sample RMSE trace.
# mse_per_chain was already restricted to good_chains when it was built, so it
# must not be indexed with good_chains again: that would pick the wrong rows
# (or go out of range) whenever good_chains is a proper subset of the chains.
gelman_split_r_hat(
    jnp.expand_dims(mse_per_chain, 2),
    n_splits=2,
    rank_normalize=True,
).item()

In [None]:
# Write the collected function-space diagnostics to a CSV file whose name
# contains the dataset and the largest value of the "layer" column of
# parameter_conv_diag_df.
function_space_csv = f'../paper_bde/practical_sbi/convergence/{DATASET}_{parameter_conv_diag_df["layer"].max()}layers_function_space.csv'
function_space_df.to_csv(function_space_csv)

In [None]:
# Split the prediction points by their split-chain R-hat (threshold 1.1) and
# report the RMSE of the good chains' predictions separately for each group.
rhat_per_point = pp_split_chain_rhat["rhat"]
n_points = preds_chain_dim.shape[2]

low_rhat = [i for i in range(n_points) if rhat_per_point[i] < 1.1]
print(f"Number of predictions with low Rhat: {len(low_rhat)}")
# Chains and samples are selected first and the points in a second subscript,
# so the two index lists are never combined in a single indexing operation.
kept_preds = preds_chain_dim[good_chains, :truncate_samples, :]
rhat_low_rmse = np.sqrt(
    mse(kept_preds[:, :, low_rhat], Y_val[low_rhat, :])[0]
)
print("RMSE for predictions with low Rhat")
print(rhat_low_rmse)

high_rhat = [i for i in range(n_points) if rhat_per_point[i] >= 1.1]
print(f"Number of predictions with high Rhat: {len(high_rhat)}")
rhat_high_rmse = np.sqrt(
    mse(kept_preds[:, :, high_rhat], Y_val[high_rhat, :])[0]
)
print("RMSE for predictions with high Rhat")
print(rhat_high_rmse)