In [None]:
import sys, os; sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__) if '__file__' in globals() else os.getcwd(), '..')))
#import os; os.chdir(os.path.dirname(os.getcwd()))
from utils.model_loader import get_model_fits
import numpy as np
import pandas as pd
import re
from sklearn.metrics import mean_squared_error
import seaborn as sns
import matplotlib.pyplot as plt


In [None]:
# Locations of the simulated Friedman datasets and of the stored posterior
# fits for each activation function.
data_dir = "datasets/friedman"
results_dir_relu = "results/regression/single_layer/relu/friedman"
results_dir_tanh = "results/regression/single_layer/tanh/friedman"
model_names_relu = ["Regularized Horseshoe", "Dirichlet Horseshoe", "Dirichlet Student T", "Beta Horseshoe", "Beta Student T"]
model_names_tanh = ["Regularized Horseshoe tanh", "Dirichlet Horseshoe tanh", "Dirichlet Student T tanh", "Beta Horseshoe tanh", "Beta Student T tanh"]

# Fits keyed by dataset configuration name (the .npz file name without its
# extension), e.g. "Friedman_N100_p10_sigma1.00_seed1".
fits_relu = {}
fits_tanh = {}

files = sorted(f for f in os.listdir(data_dir) if f.endswith(".npz"))
for fname in files:
    # Drop only the trailing extension; str.replace would also remove any
    # ".npz" occurring elsewhere in the file name.
    base_config_name = os.path.splitext(fname)[0]
    full_config_path = base_config_name
    relu_fit = get_model_fits(
        config=full_config_path,
        results_dir=results_dir_relu,
        models=model_names_relu,
        include_prior=False,
    )
    tanh_fit = get_model_fits(
        config=full_config_path,
        results_dir=results_dir_tanh,
        models=model_names_tanh,
        include_prior=False,
    )

    fits_relu[base_config_name] = relu_fit
    fits_tanh[base_config_name] = tanh_fit

In [None]:
# Locations of the correlated-covariate Friedman datasets and of the stored
# posterior fits for each activation function.
data_dir_correlated = "datasets/friedman_correlated"
results_dir_correlated_relu = "results/regression/single_layer/relu/friedman_correlated"
results_dir_correlated_tanh = "results/regression/single_layer/tanh/friedman_correlated"
model_names_correlated_relu = ["Regularized Horseshoe", "Dirichlet Horseshoe", "Dirichlet Student T", "Beta Horseshoe", "Beta Student T"]
model_names_correlated_tanh = ["Regularized Horseshoe tanh", "Dirichlet Horseshoe tanh", "Dirichlet Student T tanh", "Beta Horseshoe tanh", "Beta Student T tanh"]

# Fits keyed by dataset configuration name (the .npz file name without its
# extension), e.g. "Friedman_N100_p10_sigma1.00_seed1".
relu_fits_correlated = {}
tanh_fits_correlated = {}

files = sorted(f for f in os.listdir(data_dir_correlated) if f.endswith(".npz"))
for fname in files:
    # Drop only the trailing extension; str.replace would also remove any
    # ".npz" occurring elsewhere in the file name.
    base_config_name = os.path.splitext(fname)[0]
    full_config_path = base_config_name
    relu_fit_correlated = get_model_fits(
        config=full_config_path,
        results_dir=results_dir_correlated_relu,
        models=model_names_correlated_relu,
        include_prior=False,
    )

    tanh_fit_correlated = get_model_fits(
        config=full_config_path,
        results_dir=results_dir_correlated_tanh,
        models=model_names_correlated_tanh,
        include_prior=False,
    )

    relu_fits_correlated[base_config_name] = relu_fit_correlated
    tanh_fits_correlated[base_config_name] = tanh_fit_correlated
    


In [7]:
import numpy as np
import pandas as pd
from sklearn.metrics import mean_squared_error
from utils.generate_data import generate_Friedman_data, generate_correlated_Friedman_data
from utils.sparsity import forward_pass_tanh, forward_pass_relu
import numpy as np

def crps_from_samples(y_test, y_pred, return_distribution=False):
    """
    Sample-based CRPS estimate, CRPS = E|Y - y| - 0.5 * E|Y - Y'|.

    Parameters
    ----------
    y_test : array, shape (N,) or (N, 1)
        True test targets (flattened internally).
    y_pred : array, shape (S, N)
        S predictive samples for each of the N test points.
    return_distribution : bool
        If True, also return CRPS per test point.

    Returns
    -------
    crps_mean : float
        Mean CRPS across test points.
    crps_se : float
        Standard error of CRPS across test points (NaN when N == 1).
    crps_per_point : array, optional, shape (N,)
        Returned if return_distribution=True.
    """

    # Flatten so a column-vector target cannot broadcast against y_pred.
    y_test = np.asarray(y_test).reshape(-1)
    y_pred = np.asarray(y_pred)

    S, N = y_pred.shape
    assert y_test.shape[0] == N

    # E|Y - y|
    term1 = np.mean(np.abs(y_pred - y_test[None, :]), axis=0)

    # E|Y - Y'| over all S^2 ordered pairs, via the order-statistics identity
    #   mean_{i,j} |y_i - y_j| = (2 / S^2) * sum_k (2k - S - 1) * y_(k),
    # which needs one sort per test point instead of an S x S matrix.
    ys = np.sort(y_pred, axis=0)
    rank_weights = 2.0 * np.arange(1, S + 1) - S - 1
    term2 = 2.0 * (rank_weights @ ys) / (S * S)

    crps_per_point = term1 - 0.5 * term2

    crps_mean = crps_per_point.mean()
    # The sample standard deviation is undefined for a single test point.
    crps_se = crps_per_point.std(ddof=1) / np.sqrt(N) if N > 1 else np.nan

    if return_distribution:
        return crps_mean, crps_se, crps_per_point

    return crps_mean, crps_se

def evaluate_posterior_on_multiple_testsets(
    fits,
    models,
    forward_pass,
    seeds,
    data_func = generate_Friedman_data,
    compute_crps = False,
):
    """
    Score posterior-mean predictions of several models on freshly generated test sets.

    For every seed a test set is drawn with `data_func`; for every model each
    posterior draw of the network parameters is pushed through `forward_pass`
    and the draws are averaged to form the prediction.

    Parameters
    ----------
    fits : dict
        Maps model name to a dict whose "posterior" entry exposes
        `stan_variable(name)` for "W_1", "W_L", "hidden_bias" and "output_bias".
    models : iterable of str
        Model names (keys of `fits`) to evaluate.
    forward_pass : callable
        Called as forward_pass(X, W1, b1, W2, b2); returns predictions for X.
    seeds : iterable
        One test set is generated per seed.
    data_func : callable
        Data generator returning a 4-tuple whose 2nd and 4th items are used as
        X_test and y_test.
    compute_crps : bool
        If True, also compute the sample-based CRPS of the forward-pass outputs
        for every (model, test set) and aggregate it per model.
        NOTE(review): the forward-pass outputs carry no added observation
        noise here, so this scores the posterior over the regression function.

    Returns
    -------
    df_rmse_mean : DataFrame
        Columns "model", "mean_rmse_over_testsets".
    df_crps_mean : DataFrame
        With compute_crps=True: columns "model", "mean_crps_over_testsets",
        "se_crps_over_testsets". Otherwise a copy of the per-test-set rows
        (the historical behaviour).
    df : DataFrame
        One row per (model, test set) with "posterior_rmse" (plus "crps_mean"
        and "crps_std" when compute_crps=True).
    """
    rows = []

    for test_id in seeds:
        # Test-set size and noise level are fixed here, independent of the
        # size of the training set the model was fitted on.
        _, X_test, _, y_test = data_func(
            N=200, D=10, sigma=1.0, test_size=0.2, seed=test_id, standardize_y=True
        )

        X_test_np = X_test
        y_test_np = y_test.reshape(-1)

        for model in models:
            fit = fits[model]["posterior"]

            W1_samples = fit.stan_variable("W_1")
            W2_samples = fit.stan_variable("W_L")
            b1_samples = fit.stan_variable("hidden_bias")
            b2_samples = fit.stan_variable("output_bias")

            n_draws = W1_samples.shape[0]
            y_hats = np.zeros((n_draws, y_test_np.shape[0]))

            for i in range(n_draws):
                y_hat = forward_pass(
                    X_test_np,
                    W1_samples[i],
                    np.asarray(b1_samples[i]).reshape(-1),
                    W2_samples[i],
                    np.asarray(b2_samples[i]).reshape(-1),
                )
                y_hats[i] = y_hat.squeeze()

            # RMSE of the posterior-mean prediction (not the mean of per-draw RMSEs).
            y_mean = y_hats.mean(axis=0)
            posterior_rmse = np.sqrt(mean_squared_error(y_test_np, y_mean))

            row = {
                "model": model,
                "test_set": test_id,
                "posterior_rmse": posterior_rmse,
            }
            if compute_crps:
                crps_mean, crps_se = crps_from_samples(y_test_np, y_hats)
                row["crps_mean"] = crps_mean
                # Historical column name; the value is a standard error.
                row["crps_std"] = crps_se
            rows.append(row)

    df = pd.DataFrame(rows)

    df_rmse_mean = (
        df.groupby("model", as_index=False)["posterior_rmse"]
          .mean()
          .rename(columns={"posterior_rmse": "mean_rmse_over_testsets"})
    )

    if compute_crps:
        df_crps_mean = (
            df.groupby("model", as_index=False)
            .agg(
                mean_crps_over_testsets=("crps_mean", "mean"),
                # Combine the per-test-set standard errors in quadrature.
                se_crps_over_testsets=("crps_std",
                                       lambda x: np.sqrt(np.sum(x**2)) / len(x)),
            )
        )
    else:
        df_crps_mean = pd.DataFrame(rows)

    return df_rmse_mean, df_crps_mean, df

In [8]:
# Seeds used to generate the independent test sets.
seeds = [100, 101, 102, 103, 104]

# Fitted training configuration used for each sample-size label.
_friedman_config_by_size = {
    "smallest": 'Friedman_N50_p10_sigma1.00_seed16',
    "small": 'Friedman_N100_p10_sigma1.00_seed1',
    "medium": 'Friedman_N200_p10_sigma1.00_seed2',
    "large": 'Friedman_N500_p10_sigma1.00_seed11',
}


def _evaluate_friedman_sizes(fits_by_config, forward_pass):
    """Evaluate all four sample-size configurations for one activation.

    Returns (model names per size, evaluation triple per size).
    """
    # Resolve the model lists first so a missing configuration fails before
    # any evaluation work is done.
    model_names = {
        size: list(fits_by_config[config].keys())
        for size, config in _friedman_config_by_size.items()
    }
    evaluations = {}
    for size, config in _friedman_config_by_size.items():
        evaluations[size] = evaluate_posterior_on_multiple_testsets(
            fits=fits_by_config[config],
            models=model_names[size],
            forward_pass=forward_pass,
            seeds=seeds,
            data_func=generate_Friedman_data,
        )
    return model_names, evaluations


# --- ReLU networks ---
_relu_models, _relu_eval = _evaluate_friedman_sizes(fits_relu, forward_pass_relu)

friedman_fits_smallest_relu = _relu_models["smallest"]
friedman_fits_small_relu = _relu_models["small"]
friedman_fits_medium_relu = _relu_models["medium"]
friedman_fits_large_relu = _relu_models["large"]

df_results_smallest_relu, df_crps_smallest_relu, df_smallest_relu = _relu_eval["smallest"]
df_results_small_relu, df_crps_small_relu, df_small_relu = _relu_eval["small"]
df_results_medium_relu, df_crps_medium_relu, df_medium_relu = _relu_eval["medium"]
df_results_large_relu, df_crps_large_relu, df_large_relu = _relu_eval["large"]

# --- tanh networks ---
_tanh_models, _tanh_eval = _evaluate_friedman_sizes(fits_tanh, forward_pass_tanh)

friedman_fits_smallest_tanh = _tanh_models["smallest"]
friedman_fits_small_tanh = _tanh_models["small"]
friedman_fits_medium_tanh = _tanh_models["medium"]
friedman_fits_large_tanh = _tanh_models["large"]

df_results_smallest_tanh, df_crps_smallest_tanh, df_smallest_tanh = _tanh_eval["smallest"]
df_results_small_tanh, df_crps_small_tanh, df_small_tanh = _tanh_eval["small"]
df_results_medium_tanh, df_crps_medium_tanh, df_medium_tanh = _tanh_eval["medium"]
df_results_large_tanh, df_crps_large_tanh, df_large_tanh = _tanh_eval["large"]


In [9]:
# Seeds used to generate the independent test sets.
seeds = [100, 101, 102, 103, 104]

# Fitted training configuration used for each sample-size label
# (correlated-covariate datasets).
_correlated_config_by_size = {
    "smallest": 'Friedman_N50_p10_sigma1.00_seed16',
    "small": 'Friedman_N100_p10_sigma1.00_seed1',
    "medium": 'Friedman_N200_p10_sigma1.00_seed6',
    "large": 'Friedman_N500_p10_sigma1.00_seed11',
}


def _evaluate_correlated_sizes(fits_by_config, forward_pass):
    """Evaluate all four correlated configurations for one activation.

    Returns (model names per size, evaluation triple per size).
    """
    model_names = {
        size: list(fits_by_config[config].keys())
        for size, config in _correlated_config_by_size.items()
    }
    evaluations = {}
    for size, config in _correlated_config_by_size.items():
        evaluations[size] = evaluate_posterior_on_multiple_testsets(
            fits=fits_by_config[config],
            models=model_names[size],
            forward_pass=forward_pass,
            seeds=seeds,
            # Models fitted on correlated covariates are scored on test sets
            # from the correlated generator; previously the independent
            # generator was used here.
            # NOTE(review): assumes generate_correlated_Friedman_data accepts
            # the same keyword arguments as generate_Friedman_data - confirm.
            data_func=generate_correlated_Friedman_data,
        )
    return model_names, evaluations


# --- ReLU networks ---
_relu_models_c, _relu_eval_c = _evaluate_correlated_sizes(relu_fits_correlated, forward_pass_relu)

friedman_fits_smallest_relu_c = _relu_models_c["smallest"]
friedman_fits_small_relu_c = _relu_models_c["small"]
friedman_fits_medium_relu_c = _relu_models_c["medium"]
friedman_fits_large_relu_c = _relu_models_c["large"]

df_results_smallest_relu_c, df_crps_smallest_relu_c, df_smallest_relu_c = _relu_eval_c["smallest"]
df_results_small_relu_c, df_crps_small_relu_c, df_small_relu_c = _relu_eval_c["small"]
df_results_medium_relu_c, df_crps_medium_relu_c, df_medium_relu_c = _relu_eval_c["medium"]
df_results_large_relu_c, df_crps_large_relu_c, df_large_relu_c = _relu_eval_c["large"]

# --- tanh networks ---
_tanh_models_c, _tanh_eval_c = _evaluate_correlated_sizes(tanh_fits_correlated, forward_pass_tanh)

friedman_fits_smallest_tanh_c = _tanh_models_c["smallest"]
friedman_fits_small_tanh_c = _tanh_models_c["small"]
friedman_fits_medium_tanh_c = _tanh_models_c["medium"]
friedman_fits_large_tanh_c = _tanh_models_c["large"]

df_results_smallest_tanh_c, df_crps_smallest_tanh_c, df_smallest_tanh_c = _tanh_eval_c["smallest"]
df_results_small_tanh_c, df_crps_small_tanh_c, df_small_tanh_c = _tanh_eval_c["small"]
df_results_medium_tanh_c, df_crps_medium_tanh_c, df_medium_tanh_c = _tanh_eval_c["medium"]
df_results_large_tanh_c, df_crps_large_tanh_c, df_large_tanh_c = _tanh_eval_c["large"]


In [10]:
def _label_rmse(frame, activation, n_train, setting):
    """Return a copy of `frame` tagged with activation, training size N and data setting."""
    return frame.assign(activation=activation, N=n_train, setting=setting)


# ReLU, original data
df_1 = _label_rmse(df_results_smallest_relu, "relu", 50, 'Original')
df_2 = _label_rmse(df_results_small_relu, "relu", 100, 'Original')
df_3 = _label_rmse(df_results_medium_relu, "relu", 200, 'Original')
df_4 = _label_rmse(df_results_large_relu, "relu", 500, 'Original')

# ReLU, correlated data
df_5 = _label_rmse(df_results_smallest_relu_c, "relu", 50, 'Correlated')
df_6 = _label_rmse(df_results_small_relu_c, "relu", 100, 'Correlated')
df_7 = _label_rmse(df_results_medium_relu_c, "relu", 200, 'Correlated')
df_8 = _label_rmse(df_results_large_relu_c, "relu", 500, 'Correlated')

# tanh, original data
df_9 = _label_rmse(df_results_smallest_tanh, "tanh", 50, 'Original')
df_10 = _label_rmse(df_results_small_tanh, "tanh", 100, 'Original')
df_11 = _label_rmse(df_results_medium_tanh, "tanh", 200, 'Original')
df_12 = _label_rmse(df_results_large_tanh, "tanh", 500, 'Original')

# tanh, correlated data
df_13 = _label_rmse(df_results_smallest_tanh_c, "tanh", 50, 'Correlated')
df_14 = _label_rmse(df_results_small_tanh_c, "tanh", 100, 'Correlated')
df_15 = _label_rmse(df_results_medium_tanh_c, "tanh", 200, 'Correlated')
df_16 = _label_rmse(df_results_large_tanh_c, "tanh", 500, 'Correlated')

# One long frame with every (activation, N, setting) combination.
_rmse_frames = [
    df_1, df_2, df_3, df_4, df_5, df_6, df_7, df_8,
    df_9, df_10, df_11, df_12, df_13, df_14, df_15, df_16,
]
df_all = pd.concat(_rmse_frames, ignore_index=True)


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# --- data preparation ---
df = df_all.copy()

# Short legend labels for the priors.
abbr = {
    "Regularized Horseshoe": "RHS",
    "Dirichlet Horseshoe": "DHS",
    "Dirichlet Student T": "DST",
    "Beta Horseshoe": "BHS",
    "Beta Student T": "BST",
}

# Strip the " tanh" suffix so both activations share one model name.
df["model_clean"] = df["model"].str.replace(" tanh", "", regex=False)

# Mean / std per (setting, N, model, activation).
# NOTE(review): if df_all holds a single row per such group (as when it is
# built from the per-model means over test sets), `std` is NaN and no error
# bars are drawn - confirm this is intended.
group_cols = ["setting", "N", "model_clean", "activation"]
summary = (
    df.groupby(group_cols, as_index=False)["mean_rmse_over_testsets"]
      .agg(mean="mean", std="std")
)

# Panel, x-axis and model ordering.
settings = ["Original", "Correlated"]
Ns = [50, 100, 200, 500]
models = ["Regularized Horseshoe", "Dirichlet Horseshoe", "Dirichlet Student T", "Beta Horseshoe", "Beta Student T"]

# Marker shape encodes the activation; both offsets spread points around each N tick.
markers = {"tanh": "o", "relu": "X"}
offsets = {"tanh": -0.12, "relu": +0.12}
model_offsets = {
    "Regularized Horseshoe": -0.07,
    "Dirichlet Horseshoe": -0.03,
    "Dirichlet Student T": 0.00,
    "Beta Horseshoe": +0.03,
    "Beta Student T": +0.07,
}
palette_list = plt.get_cmap("tab10").colors
palette = {m: palette_list[i+1] for i, m in enumerate(models)}

# Integer x position for every training size.
xbase = {N: i for i, N in enumerate(Ns)}

fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharey=True)
for ax_num, (ax, setting) in enumerate(zip(axes, settings)):
    sub = summary[summary["setting"] == setting]
    # Markers with error bars only, no connecting lines.
    for m in models:
        for act in ("tanh", "relu"):
            g = sub[(sub["model_clean"] == m) & (sub["activation"] == act)]
            if not g.empty:
                xs = [xbase[n] + offsets[act] + model_offsets[m] for n in g["N"]]
                ax.errorbar(
                    xs, g["mean"], yerr=g["std"],
                    fmt=markers[act], markersize=12,
                    linestyle="none", capsize=3,
                    color=palette[m], markeredgecolor="black"
                )
    ax.set_title(f"{setting}", fontsize=15)
    ax.set_xticks(range(len(Ns)))
    ax.set_xticklabels(Ns, fontsize=15)
    ax.set_xlabel("N", fontsize=15)
    # Shared y-axis: label only the left panel.
    if ax_num == 0:
        ax.set_ylabel("RMSE", fontsize=15)
    ax.grid()

# --- legend entries ---
# Colour encodes the prior ...
model_handles = [
    Line2D(
        [0], [0],
        marker="o",
        linestyle="none",
        color=palette[m],
        markeredgecolor="black",
        markersize=12,
        label=abbr.get(m, m)
    )
    for m in models
]

# ... and marker shape encodes the activation.
activation_handles = [
    Line2D([0], [0], marker=markers[act], linestyle="none", color="black",
           markersize=12, label=label)
    for act, label in (("tanh", "tanh"), ("relu", "ReLU"))
]

legend_options = dict(
    title=None,
    loc="upper right",
    frameon=False,
    ncol=1,
    fontsize=15,
)
for ax in axes:
    ax.legend(handles=model_handles + activation_handles, **legend_options)

plt.tight_layout(rect=(0, 0, 1, 1))
plt.show()

In [9]:
def _label_crps(frame, activation, n_train, setting):
    """Return a copy of `frame` tagged with activation, training size N and data setting."""
    return frame.assign(activation=activation, N=n_train, setting=setting)


# Only the N=50 and N=100 configurations are included in the CRPS comparison.
df_crps_1 = _label_crps(df_crps_smallest_relu, "relu", 50, 'Original')
df_crps_2 = _label_crps(df_crps_small_relu, "relu", 100, 'Original')

df_crps_5 = _label_crps(df_crps_smallest_relu_c, "relu", 50, 'Correlated')
df_crps_6 = _label_crps(df_crps_small_relu_c, "relu", 100, 'Correlated')

df_crps_9 = _label_crps(df_crps_smallest_tanh, "tanh", 50, 'Original')
df_crps_10 = _label_crps(df_crps_small_tanh, "tanh", 100, 'Original')

df_crps_13 = _label_crps(df_crps_smallest_tanh_c, "tanh", 50, 'Correlated')
df_crps_14 = _label_crps(df_crps_small_tanh_c, "tanh", 100, 'Correlated')

_crps_frames = [
    df_crps_1, df_crps_2, df_crps_5, df_crps_6,
    df_crps_9, df_crps_10, df_crps_13, df_crps_14,
]
df_crps = pd.concat(_crps_frames, ignore_index=True)

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# --- prepare data ---
df = df_crps.copy()

# Short legend labels for the priors.
abbr = {
    "Regularized Horseshoe": "RHS",
    "Dirichlet Horseshoe": "DHS",
    "Dirichlet Student T": "DST",
    "Beta Horseshoe": "BHS",
    "Beta Student T": "BST",
}

# unify model names across activations (strip " tanh")
df["model_clean"] = df["model"].str.replace(" tanh", "", regex=False)

# plotting order
settings = ["Original", "Correlated"]
Ns = [50, 100, 200, 500]
models = ["Regularized Horseshoe", "Dirichlet Horseshoe", "Dirichlet Student T", "Beta Horseshoe", "Beta Student T"]

# visuals
markers = {"tanh": "o", "relu": "X"}            # shapes
offsets = {"tanh": -0.12, "relu": +0.12}        # side-by-side jitter on x
model_offsets = {
    "Regularized Horseshoe": -0.07,
    "Dirichlet Horseshoe": -0.03,
    "Dirichlet Student T": 0.00,
    "Beta Horseshoe": +0.03,
    "Beta Student T": +0.07,
}
palette_list = plt.get_cmap("tab10").colors
palette = {m: palette_list[i+1] for i, m in enumerate(models)}

# map N to base x positions and add offsets for activation
xbase = {N: i for i, N in enumerate(Ns)}

# The figure needs the aggregated CRPS columns. When the evaluation was run
# without CRPS aggregation they are absent and indexing them raises KeyError,
# so report that and skip the figure instead of failing the whole run.
crps_columns = ["mean_crps_over_testsets", "se_crps_over_testsets"]
missing_columns = [c for c in crps_columns if c not in df.columns]

if missing_columns:
    print(f"[SKIP] CRPS plot: df_crps is missing column(s) {missing_columns}")
else:
    fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharey=True)
    for ax_num, (ax, setting) in enumerate(zip(axes, settings)):
        sub = df[df["setting"] == setting]
        # plot each model+activation with errorbars, without lines
        for m in models:
            for act in ["tanh", "relu"]:
                g = sub[(sub["model_clean"] == m) & (sub["activation"] == act)]
                if g.empty:
                    continue
                xs = [xbase[n] + offsets[act] + model_offsets[m] for n in g["N"]]

                ax.errorbar(
                    xs, g["mean_crps_over_testsets"], yerr=g["se_crps_over_testsets"],
                    fmt=markers[act], markersize=12,
                    linestyle="none", capsize=3,
                    color=palette[m], markeredgecolor="black"
                )
        ax.set_title(f"{setting}", fontsize=15)
        ax.set_xticks(range(len(Ns)))
        ax.set_xticklabels(Ns, fontsize=15)
        ax.set_xlabel("N", fontsize=15)
        # Shared y-axis: label only the left panel.
        if ax_num == 0:
            ax.set_ylabel("CRPS", fontsize=15)
        ax.grid()

    # --- legends ---
    # model legend (colours)
    model_handles = [
        Line2D(
            [0], [0],
            marker="o",
            linestyle="none",
            color=palette[m],
            markeredgecolor="black",
            markersize=12,
            label=abbr.get(m, m)   # <- use abbreviation
        )
        for m in models
    ]

    # activation legend (shapes)
    activation_handles = [
        Line2D([0], [0], marker=markers["tanh"], linestyle="none", color="black",
               markersize=12, label="tanh"),
        Line2D([0], [0], marker=markers["relu"], linestyle="none", color="black",
               markersize=12, label="ReLU"),
    ]

    for ax in axes:
        ax.legend(
            handles=model_handles + activation_handles,
            title=None,
            loc="upper right",
            frameon=False,
            ncol=1,
            fontsize=15
        )
    plt.tight_layout(rect=(0, 0, 1, 1))
    plt.show()

## SPARSITY

In [12]:
from utils.sparsity import forward_pass_relu, forward_pass_tanh, local_prune_weights

def compute_sparse_rmse_results(seeds, models, all_fits, get_N_sigma, forward_pass, folder,
                         sparsity=0.0, prune_fn=None):
    """
    Test-set RMSE of (optionally pruned) posterior draws for several datasets and models.

    Parameters
    ----------
    seeds : iterable
        Dataset seeds; each is turned into a dataset key via `get_N_sigma`.
    models : iterable of str
        Model names to look up under each dataset key in `all_fits`.
    all_fits : dict
        all_fits[dataset_key][model]['posterior'] must expose
        `stan_variable(name)` for "W_1", "W_L", "hidden_bias", "output_bias".
    get_N_sigma : callable
        Maps a seed to (N, sigma), used to build the dataset key.
    forward_pass : callable
        Called as forward_pass(X, W1, b1, W2, b2); returns predictions for X.
    folder : str
        Sub-folder of "datasets/" holding the .npz files, which must contain
        "X_test" and "y_test".
    sparsity : float
        Pruning level handed to `prune_fn`; 0.0 disables pruning.
    prune_fn : callable or None
        Called as prune_fn([W1, W2], sparsity) and expected to return one mask
        per weight matrix. Only the first-layer mask is applied.

    Returns
    -------
    df_rmse : DataFrame
        One row per posterior draw with its "rmse".
    df_posterior_rmse : DataFrame
        One row per (dataset, model) with "posterior_mean_rmse", the RMSE of
        the prediction averaged over draws.

    Missing dataset files and missing fits are reported with a "[SKIP]" line
    and skipped.
    """
    results = []
    posterior_means = []
    apply_pruning = prune_fn is not None and sparsity > 0.0

    for seed in seeds:
        N, sigma = get_N_sigma(seed)
        dataset_key = f'Friedman_N{N}_p10_sigma{sigma:.2f}_seed{seed}'
        path = f"datasets/{folder}/{dataset_key}.npz"

        try:
            # Context manager closes the archive's file handle after reading.
            with np.load(path) as data:
                X_test = data["X_test"]
                # Flatten once: an (N, 1) target would otherwise broadcast
                # against (N,) predictions into an (N, N) error matrix.
                y_test = np.asarray(data["y_test"]).reshape(-1)
        except FileNotFoundError:
            print(f"[SKIP] File not found: {path}")
            continue

        for model in models:
            try:
                fit = all_fits[dataset_key][model]['posterior']
                W1_samples = fit.stan_variable("W_1")
                W2_samples = fit.stan_variable("W_L")
                b1_samples = fit.stan_variable("hidden_bias")
                b2_samples = fit.stan_variable("output_bias")
            except KeyError:
                print(f"[SKIP] Model or posterior not found: {dataset_key} -> {model}")
                continue

            S = W1_samples.shape[0]
            y_hats = np.zeros((S, y_test.shape[0]))

            for i in range(S):
                W1 = W1_samples[i]
                W2 = W2_samples[i]

                # Masks are computed for both layers but only the input
                # layer is pruned.
                if apply_pruning:
                    masks = prune_fn([W1, W2], sparsity)
                    W1 = W1 * masks[0]

                y_hat = forward_pass(X_test, W1, b1_samples[i][0], W2, b2_samples[i])
                y_hats[i] = np.asarray(y_hat).reshape(-1)

            # RMSE of every individual draw, and of the draw-averaged prediction.
            rmses = np.sqrt(np.mean((y_hats - y_test) ** 2, axis=1))
            posterior_mean = y_hats.mean(axis=0)
            posterior_mean_rmse = np.sqrt(np.mean((posterior_mean - y_test) ** 2))

            posterior_means.append({
                'seed': seed,
                'N': N,
                'sigma': sigma,
                'model': model,
                'sparsity': sparsity,
                'posterior_mean_rmse': posterior_mean_rmse
            })

            results.extend(
                {
                    'seed': seed,
                    'N': N,
                    'sigma': sigma,
                    'model': model,
                    'sparsity': sparsity,
                    'rmse': draw_rmse
                }
                for draw_rmse in rmses
            )

    df_rmse = pd.DataFrame(results)
    df_posterior_rmse = pd.DataFrame(posterior_means)

    return df_rmse, df_posterior_rmse


# Fractions of weights to prune: 0.0 to 0.9 in steps of 0.1, plus 0.95.
sparsity_levels = [k / 10 for k in range(10)] + [0.95]

# Dataset seeds of the fitted configurations (original / correlated data).
# NOTE: this rebinds `seeds`, which the evaluation cells above use for the
# test-set seeds.
seeds = [1, 2, 11, 16]
seeds_correlated = [1, 6, 11, 16]

def get_N_sigma(seed):
    """Return (N, sigma) for a dataset seed of the original Friedman data.

    Seeds 1, 2 and 16 map to N = 100, 200 and 50; every other seed maps to
    N = 500. sigma is always 1.0.
    """
    for known_seed, n_train in ((1, 100), (2, 200), (16, 50)):
        if seed == known_seed:
            return n_train, 1.00
    return 500, 1.00

def get_N_sigma_correlated(seed):
    """Return (N, sigma) for a dataset seed of the correlated Friedman data.

    Seeds 1, 6 and 16 map to N = 100, 200 and 50; every other seed maps to
    N = 500. sigma is always 1.0.
    """
    for known_seed, n_train in ((1, 100), (6, 200), (16, 50)):
        if seed == known_seed:
            return n_train, 1.00
    return 500, 1.00

In [13]:
# Per-draw and posterior-mean RMSE frames, keyed by sparsity level.
df_rmse_sparse, df_posterior_rmse_sparse = {}, {}
df_rmse_sparse_correlated, df_posterior_rmse_sparse_correlated = {}, {}

for sparsity in sparsity_levels:
    # Original data: use the model list belonging to these fits rather than
    # the correlated one (the two lists currently hold the same names).
    df_rmse_sparse[sparsity], df_posterior_rmse_sparse[sparsity] = compute_sparse_rmse_results(
        seeds, model_names_tanh, fits_tanh, get_N_sigma, forward_pass_tanh, folder = "friedman",
        sparsity=sparsity, prune_fn=local_prune_weights
    )

    # Correlated data.
    df_rmse_sparse_correlated[sparsity], df_posterior_rmse_sparse_correlated[sparsity] = compute_sparse_rmse_results(
        seeds_correlated, model_names_correlated_tanh, tanh_fits_correlated, get_N_sigma_correlated, forward_pass_tanh, folder = "friedman_correlated",
        sparsity=sparsity, prune_fn=local_prune_weights
    )
    

In [14]:
import pandas as pd

def _stack_by_sparsity(frames_by_sparsity):
    """Concatenate the per-sparsity frames, stamping each with its sparsity level."""
    stamped = [
        frame.assign(sparsity=level)
        for level, frame in frames_by_sparsity.items()
    ]
    return pd.concat(stamped, ignore_index=True)


def _without_tanh_suffix(frame):
    """Return a copy of `frame` whose model names have the " tanh" suffix removed."""
    return frame.assign(model=frame["model"].str.replace(" tanh", "", regex=False))


# Posterior-mean RMSE across all sparsity levels, per data setting.
df_rmse_full = _stack_by_sparsity(df_posterior_rmse_sparse)
df_rmse_full_correlated = _stack_by_sparsity(df_posterior_rmse_sparse_correlated)

# Same frames with activation-free model names (o = original, c = correlated).
df_local_o = _without_tanh_suffix(df_rmse_full)
df_local_c = _without_tanh_suffix(df_rmse_full_correlated)

In [15]:
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.lines import Line2D
import pandas as pd
from collections import OrderedDict

# One tab10 colour per model, skipping the first colour of the map.
# Relies on `models` having been defined by an earlier plotting cell.
palette_list = plt.get_cmap("tab10").colors
palette = {m: palette_list[i+1] for i, m in enumerate(models)}

def make_merged_df(
    df_local_o, df_local_c,
    drop_tanh_suffix=True
):
    """
    Stack the original and correlated result frames into one long frame.

    Parameters
    ----------
    df_local_o, df_local_c : DataFrame
        Results for the original and the correlated data. Each needs a
        "model" column; all other columns are carried over unchanged.
    drop_tanh_suffix : bool
        If True, remove " tanh" from the model names.

    Returns
    -------
    DataFrame
        Copy of both inputs, concatenated with a fresh index and two added
        columns: "activation" (always "Tanh") and "setting" ("Original" or
        "Correlated"). The inputs are not modified.
    """
    tagged = []
    for frame, setting in ((df_local_o, "Original"), (df_local_c, "Correlated")):
        d = frame.copy()
        # str.replace leaves names without the suffix untouched, so no
        # pre-check of the names is needed.
        if drop_tanh_suffix and not d.empty:
            d["model"] = d["model"].str.replace(" tanh", "", regex=False)
        d["activation"] = "Tanh"
        d["setting"] = setting
        tagged.append(d)

    # Only one activation is present, so there is nothing to filter by
    # activation here.
    return pd.concat(tagged, ignore_index=True)

# NOTE: rebinds `df_all`, which earlier cells use for the RMSE-vs-N frame;
# from here on it holds the merged sparsity results.
df_all = make_merged_df(df_local_o, df_local_c)


In [16]:
def plot_rmse_one_figure(
    df_all,
    Ns=(100, 200, 500),
    figsize=(12, 7),
    title="Original vs Correlated"
):
    """
    Plot posterior-mean RMSE against sparsity on a 2 x len(Ns) grid.

    Rows are the data settings ("Original", "Correlated"), columns the
    training sizes in `Ns`; each panel has one line per model.

    Parameters
    ----------
    df_all : DataFrame
        Needs the columns "N", "setting", "model", "sparsity" and
        "posterior_mean_rmse".
    Ns : sequence of int
        Training sizes to show, one column each.
    figsize : tuple
        Figure size passed to matplotlib.
    title : str
        Figure title.

    Notes
    -----
    Reads the notebook-level `abbr` (model -> short label) and `palette`
    (model -> colour) and updates seaborn / matplotlib global style settings.
    Panels without data are hidden.
    """
    # Row ordering
    setting_order = ["Original", "Correlated"]

    sns.set_style("whitegrid")
    plt.rcParams.update({
        "axes.spines.top": False,
        "axes.spines.right": False,
        "legend.frameon": True,
    })

    # squeeze=False always yields a 2-D axes array, also for a single N.
    fig, axes = plt.subplots(
        2, len(Ns), figsize=figsize, sharex=True, sharey="col", squeeze=False
    )

    def _short(model_name):
        # Fall back to the full name for models without an abbreviation.
        return abbr.get(model_name, model_name)

    palette_rank = {m: rank for rank, m in enumerate(palette)}

    for j, Nval in enumerate(Ns):
        for i, setting in enumerate(setting_order):
            ax = axes[i, j]
            dfN = df_all[(df_all["N"] == Nval) & (df_all["setting"] == setting)].copy()

            # If empty, hide this subplot
            if dfN.empty:
                ax.set_visible(False)
                continue

            # Abbreviated labels for models
            dfN["model_abbr"] = dfN["model"].map(_short)

            # Models in palette order; unknown models go last.
            present = sorted(
                dfN["model"].unique(),
                key=lambda m: palette_rank.get(m, 999),
            )
            hue_order = [_short(m) for m in present]

            # Palette keyed by the *abbreviated* names. If some model has no
            # palette entry, let seaborn choose the colours rather than fail
            # on an incomplete mapping.
            color_map = {_short(m): palette[m] for m in present if m in palette}
            panel_palette = color_map if len(color_map) == len(hue_order) else None

            sns.lineplot(
                data=dfN,
                x="sparsity",
                y="posterior_mean_rmse",
                hue="model_abbr",       # color = prior (abbr)
                # `markers=` / `dashes=` only act through a `style` variable;
                # a plain matplotlib marker draws the points that the global
                # legend advertises.
                marker="o",
                palette=panel_palette,
                hue_order=hue_order,
                errorbar=None,
                ax=ax,
            )

            # Name the setting too, so the two rows can be told apart.
            ax.set_title(f"{setting}: N={Nval}", fontweight="normal")
            ax.set_xlabel("Sparsity")
            ax.set_ylabel("RMSE" if j == 0 else "")
            ax.grid(True, which="major", alpha=0.25)

            # Remove per-axes legends; one global legend is added below.
            panel_legend = ax.get_legend()
            if panel_legend is not None:
                panel_legend.remove()

    # ---------- Global legend for priors (colors) ----------
    known_models = ["Regularized Horseshoe", "Dirichlet Horseshoe", "Dirichlet Student T",
                    "Beta Horseshoe", "Beta Student T"]
    models_present = [m for m in known_models if (df_all["model"] == m).any()]

    prior_handles = [
        Line2D(
            [0], [0],
            color=palette[m],
            marker="o",
            linestyle="-",
            linewidth=2,
            markersize=7
        )
        for m in models_present
    ]
    prior_labels = [abbr[m] for m in models_present]

    if prior_handles:
        fig.legend(
            prior_handles,
            prior_labels,
            title="Prior",
            loc="upper center",
            ncol=len(prior_handles),
            frameon=True,
            bbox_to_anchor=(0.5, 1.02),
        )

    fig.suptitle(title, y=1.05)
    plt.tight_layout(rect=[0.02, 0.02, 0.98, 0.96])
    plt.show()


In [None]:

# Draw the sparsity figure for all four training sizes (overrides the
# function's default Ns=(100, 200, 500)).
plot_rmse_one_figure(df_all,
                     Ns=(50, 100, 200, 500),
                     title="Original vs Correlated")


## SAMPLING

In [22]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import dirichlet, beta, cauchy

# Seed NumPy's global RNG. The sampling helpers below call np.random.randn and
# scipy.stats `rvs` without an explicit random_state, so they draw from it.
np.random.seed(1)

# dimensions
P = 10          # number of coefficients
S = 1_000      # prior samples (NOTE: rebound to 50 by the histogram cell further down)

# hyperparameters
tau_scale = 1.0        # scale of the half-Cauchy draw for the global scale tau
alpha_dir = 0.1        # Dirichlet concentration
alpha_beta = 0.1      # Beta concentration

from scipy.stats import invgamma

def regularize_lambda(lambda_, tau, a=2.0, b=4.0, eps=1e-12):
    """
    Regularized-horseshoe shrinkage of the local scales.

    One slab variance c^2 ~ InvGamma(a, scale=b) is drawn per row of
    `lambda_`, and the regularized local variance

        lambda_tilde^2 = c^2 * lambda^2 / (c^2 + tau^2 * lambda^2)

    is evaluated element-wise. The function returns lambda_tilde itself,
    i.e. the square root, after flooring lambda_tilde^2 at `eps`.

    Parameters
    ----------
    lambda_ : ndarray, shape (S, P)
        Local scales; must be 2-D (a ValueError is raised otherwise).
    tau : ndarray, shape (S,)
        Global scale, one value per row of `lambda_`.
    a, b : float
        Shape and scale of the inverse-gamma draw for c^2.
    eps : float
        Lower bound applied to lambda_tilde^2 before the square root.

    Returns
    -------
    ndarray, shape (S, P)
        Regularized local scales lambda_tilde.
    """
    # Unpacking doubles as a check that `lambda_` is two-dimensional.
    n_draws, _ = lambda_.shape
    slab_var = invgamma.rvs(a=a, scale=b, size=n_draws)[:, np.newaxis]

    local_var = np.square(lambda_)
    global_var = np.square(tau[:, np.newaxis])

    shrunk_var = (slab_var * local_var) / (slab_var + global_var * local_var)
    return np.sqrt(np.maximum(shrunk_var, eps))

def sample_weights_rhs(P, S):
    """
    Draw S prior samples of P weights under the regularized horseshoe.

    Each weight is tau * lambda_tilde * z, where tau is a half-Cauchy draw
    with scale `tau_scale` (one per sample), lambda is a half-Cauchy(1) draw
    per weight that is passed through `regularize_lambda`, and z is standard
    normal noise.

    Returns
    -------
    ndarray, shape (S, P)
    """
    global_scale = np.abs(cauchy.rvs(scale=tau_scale, size=S))
    local_scale = np.abs(cauchy.rvs(scale=1.0, size=(S, P)))
    local_scale_reg = regularize_lambda(local_scale, global_scale)
    noise = np.random.randn(S, P)
    return global_scale[:, np.newaxis] * local_scale_reg * noise

def sample_weights_dirichlet(P, S):
    """
    Draw S prior samples of P weights with Dirichlet-distributed variance shares.

    Same construction as the regularized horseshoe (half-Cauchy tau with scale
    `tau_scale`, half-Cauchy(1) lambda passed through `regularize_lambda`), with
    each weight additionally scaled by sqrt(xi), where every row of xi is one
    draw from a symmetric Dirichlet with concentration `alpha_dir`.

    Returns
    -------
    ndarray, shape (S, P)
    """
    global_scale = np.abs(cauchy.rvs(scale=tau_scale, size=S))
    local_scale = np.abs(cauchy.rvs(scale=1.0, size=(S, P)))
    local_scale_reg = regularize_lambda(local_scale, global_scale)
    shares = dirichlet.rvs([alpha_dir] * P, size=S)
    noise = np.random.randn(S, P)
    scale = global_scale[:, np.newaxis] * local_scale_reg * np.sqrt(shares)
    return scale * noise

def sample_weights_beta(P, S):
    """
    Draw S prior samples of P weights with normalized-Beta variance shares.

    Same construction as `sample_weights_dirichlet`, except that the shares
    xi are built from independent Beta(alpha_beta, (P-1)*alpha_beta) draws
    which are then divided by their row sum so that each row adds up to one.

    Returns
    -------
    ndarray, shape (S, P)
    """
    global_scale = np.abs(cauchy.rvs(scale=tau_scale, size=S))
    local_scale = np.abs(cauchy.rvs(scale=1.0, size=(S, P)))
    local_scale_reg = regularize_lambda(local_scale, global_scale)
    raw_shares = beta.rvs(alpha_beta, (P - 1) * alpha_beta, size=(S, P))
    shares = raw_shares / raw_shares.sum(axis=1, keepdims=True)
    noise = np.random.randn(S, P)
    scale = global_scale[:, np.newaxis] * local_scale_reg * np.sqrt(shares)
    return scale * noise


In [None]:
# Draw a small batch of prior samples from each prior and compare the marginal
# distribution of the first weight under the Dirichlet and Beta constructions.
S = 50  # NOTE: rebinds the module-level S (1_000) defined with the sampling constants
w_dir = sample_weights_dirichlet(P, S)
w_beta = sample_weights_beta(P, S)
w_rhs = sample_weights_rhs(P, S)  # not plotted here; consumed by the divergence cell below

bins = S // 5  # integer bin count, so no cast is needed at each use

fig, ax = plt.subplots(figsize=(7, 5))
# Mathtext keeps the Greek letter ASCII-safe in the source file.
ax.hist(w_dir[:, 0], bins=bins, density=True, alpha=0.5, label=r"Dirichlet $\xi$")
ax.hist(w_beta[:, 0], bins=bins, density=True, alpha=0.5, label=r"Beta $\xi$")
ax.set_ylim(0, 3)
ax.set_xlabel("Weight value")
ax.set_ylabel("Density")
ax.legend()
ax.set_title("Marginal prior distribution of weights")
plt.show()


In [None]:
import numpy as np

def kl_js_from_hist(samples_p, samples_q, bins=200, range=(-1, 1), eps=1e-12):
    """
    Approximate KL(P||Q), KL(Q||P) and JS(P,Q) from two sample sets.

    Both sample sets are discretized onto the same histogram bins over a
    fixed range. Samples that fall outside `range` are dropped by the
    histogram, so the result compares the two distributions restricted to
    that window. Empty bins are floored at `eps` (followed by a
    renormalization) so the logarithms stay finite.

    Parameters:
        samples_p, samples_q: array-like samples from P and Q.
        bins: passed to `np.histogram`.
        range: (low, high) window passed to `np.histogram`.
        eps: floor applied to the bin probabilities.

    Returns:
        kl_pq, kl_qp, js  (natural logarithm, i.e. in nats)

        If either sample set has no value inside `range`, the divergences
        are undefined: a RuntimeWarning is issued and (nan, nan, nan) is
        returned.
    """
    import warnings

    p_counts, _ = np.histogram(samples_p, bins=bins, range=range, density=False)
    q_counts, _ = np.histogram(samples_q, bins=bins, range=range, density=False)

    # Without this guard the normalization below is 0/0 and NaNs propagate
    # silently through every returned value.
    if p_counts.sum() == 0 or q_counts.sum() == 0:
        warnings.warn(
            "kl_js_from_hist: at least one sample set has no values inside "
            f"range={range}; divergences are undefined (returning NaN).",
            RuntimeWarning,
            stacklevel=2,
        )
        return np.nan, np.nan, np.nan

    def _smoothed_pmf(counts):
        # counts -> probabilities, floor zeros at eps, renormalize to sum to 1
        pmf = counts.astype(float)
        pmf = pmf / pmf.sum()
        pmf = np.clip(pmf, eps, None)
        return pmf / pmf.sum()

    p = _smoothed_pmf(p_counts)
    q = _smoothed_pmf(q_counts)

    # KL divergences
    kl_pq = np.sum(p * np.log(p / q))
    kl_qp = np.sum(q * np.log(q / p))

    # Jensen-Shannon divergence: mean KL of P and Q to their mixture M
    m = 0.5 * (p + q)
    js = 0.5 * np.sum(p * np.log(p / m)) + 0.5 * np.sum(q * np.log(q / m))

    return kl_pq, kl_qp, js

# Pairwise divergences between the first-weight marginals of the three priors.
# Every comparison uses the same histogram grid: 200 bins over [-1, 1].
# Each result gets its own name so no pair overwrites another pair's values.

# Dirichlet vs Beta
kl_dir_beta, kl_beta_dir, js_dir_beta = kl_js_from_hist(
    w_dir[:, 0], w_beta[:, 0],
    bins=200, range=(-1, 1), eps=1e-12
)

print(f"KL(Dir || Beta) over [-1,1]: {kl_dir_beta:.6g}")
print(f"KL(Beta || Dir) over [-1,1]: {kl_beta_dir:.6g}")
print(f"JS(Dir, Beta) over [-1,1]:   {js_dir_beta:.6g}")

# Dirichlet vs regularized horseshoe
kl_dir_rhs, kl_rhs_dir, js_dir_rhs = kl_js_from_hist(
    w_dir[:, 0], w_rhs[:, 0],
    bins=200, range=(-1, 1), eps=1e-12
)

print(f"KL(Dir || RHS) over [-1,1]: {kl_dir_rhs:.6g}")
print(f"KL(RHS || Dir) over [-1,1]: {kl_rhs_dir:.6g}")
print(f"JS(Dir, RHS) over [-1,1]:   {js_dir_rhs:.6g}")

# Beta vs regularized horseshoe
kl_beta_rhs, kl_rhs_beta, js_beta_rhs = kl_js_from_hist(
    w_beta[:, 0], w_rhs[:, 0],
    bins=200, range=(-1, 1), eps=1e-12
)

print(f"KL(Beta || RHS) over [-1,1]: {kl_beta_rhs:.6g}")
print(f"KL(RHS || Beta) over [-1,1]: {kl_rhs_beta:.6g}")
print(f"JS(Beta, RHS) over [-1,1]:   {js_beta_rhs:.6g}")


## ABALONE

In [None]:
# Abalone experiment: result folders and the priors fitted for each activation.
# NOTE: this rebinds data_dir, results_dir_* and model_names_* that the
# Friedman cells at the top of the notebook also define.
data_dir = "datasets/abalone"
results_dir_relu = "results/regression/single_layer/relu/abalone"
results_dir_tanh = "results/regression/single_layer/tanh/abalone"

model_names_relu = [
    "Dirichlet Horseshoe",
    "Dirichlet Student T",
    "Beta Horseshoe",
    "Beta Student T",
]
# The tanh fits are stored under the same prior names with a " tanh" suffix.
model_names_tanh = [f"{prior_name} tanh" for prior_name in model_names_relu]

full_config_path = "abalone_N3341_p8"

abalone_relu_fit = get_model_fits(
    config=full_config_path, results_dir=results_dir_relu,
    models=model_names_relu, include_prior=False,
)
abalone_tanh_fit = get_model_fits(
    config=full_config_path, results_dir=results_dir_tanh,
    models=model_names_tanh, include_prior=False,
)

In [None]:
from sklearn.metrics import mean_squared_error
from properscoring import crps_ensemble
import numpy as np
import pandas as pd

# NOTE: `y_test` has to be the same held-out set that produced `output_test`
# in Stan; otherwise the scores below are not comparable.
from utils.generate_data import load_abalone_regression_data
X_train, X_test, y_train, y_test = load_abalone_regression_data(standardized=False, frac=1.0)

# One summary row per ReLU model: RMSE of the posterior mean, mean per-draw
# RMSE, and the ensemble CRPS.
rows = []
for model_name, model_entry in abalone_relu_fit.items():
    post = model_entry["posterior"]

    # Posterior predictive draws with the trailing singleton axis dropped,
    # expected shape (S, n_test).
    y_samps = post.stan_variable("output_test").squeeze(-1)

    # RMSE of the posterior-mean prediction
    y_mean = np.mean(y_samps, axis=0)
    rmse_post_mean = float(np.sqrt(mean_squared_error(y_test, y_mean)))

    # RMSE of every individual draw, then averaged over draws
    sq_err = np.square(y_samps - y_test[None, :])
    per_draw_rmse = np.sqrt(np.mean(sq_err, axis=1))
    rmse_draw_mean = float(np.mean(per_draw_rmse))

    # crps_ensemble wants the ensemble members on the last axis: (n_test, S)
    crps = float(np.mean(crps_ensemble(y_test, y_samps.T)))

    rows.append(dict(
        Model=model_name,
        RMSE_posterior_mean=rmse_post_mean,
        RMSE_mean_over_draws=rmse_draw_mean,
        CRPS=crps,
        n_draws=y_samps.shape[0],
    ))

results_df = pd.DataFrame(rows).sort_values("RMSE_posterior_mean")
print(results_df)


In [None]:
from sklearn.metrics import mean_squared_error
from properscoring import crps_ensemble
import numpy as np
import pandas as pd

# NOTE: `y_test` has to be the same held-out set that produced `output_test`
# in Stan; otherwise the scores below are not comparable.
from utils.generate_data import load_abalone_regression_data
X_train, X_test, y_train, y_test = load_abalone_regression_data(standardized=False, frac=1.0)

# One summary row per tanh model: RMSE of the posterior mean, mean per-draw
# RMSE, and the ensemble CRPS.
rows = []
for model_name, model_entry in abalone_tanh_fit.items():
    post = model_entry["posterior"]

    # Posterior predictive draws with the trailing singleton axis dropped,
    # expected shape (S, n_test).
    y_samps = post.stan_variable("output_test").squeeze(-1)

    # RMSE of the posterior-mean prediction
    y_mean = np.mean(y_samps, axis=0)
    rmse_post_mean = float(np.sqrt(mean_squared_error(y_test, y_mean)))

    # RMSE of every individual draw, then averaged over draws
    sq_err = np.square(y_samps - y_test[None, :])
    per_draw_rmse = np.sqrt(np.mean(sq_err, axis=1))
    rmse_draw_mean = float(np.mean(per_draw_rmse))

    # crps_ensemble wants the ensemble members on the last axis: (n_test, S)
    crps = float(np.mean(crps_ensemble(y_test, y_samps.T)))

    rows.append(dict(
        Model=model_name,
        RMSE_posterior_mean=rmse_post_mean,
        RMSE_mean_over_draws=rmse_draw_mean,
        CRPS=crps,
        n_draws=y_samps.shape[0],
    ))

results_df = pd.DataFrame(rows).sort_values("RMSE_posterior_mean")
print(results_df)


In [20]:
from utils.generate_data import load_abalone_regression_data
def compute_sparse_rmse_results_abalone(models, all_fits, forward_pass,
                         sparsity=0.0, prune_fn=None, X_test=None, y_test=None):
    """
    Test-set RMSE of (optionally pruned) single-hidden-layer posterior draws.

    For every model in `models`, each posterior draw of the network
    parameters is pushed through `forward_pass` on the test inputs. When
    `prune_fn` is given and `sparsity > 0`, the input-to-hidden weights W_1
    of each draw are multiplied by the first mask returned by
    `prune_fn([W1, W2], sparsity)`; the output weights W_L are left unpruned.

    Parameters
    ----------
    models : iterable of str
        Keys of `all_fits` to evaluate. Entries that raise KeyError while
        being looked up are reported by name and skipped.
    all_fits : mapping
        `all_fits[model]['posterior']` must expose `stan_variable(name)` for
        "W_1", "W_L", "hidden_bias" and "output_bias".
    forward_pass : callable
        Called as `forward_pass(X_test, W1, b1, W2, b2)` for one draw.
    sparsity : float
        Sparsity level forwarded to `prune_fn` and recorded in the output.
    prune_fn : callable or None
        Mask generator; pruning is skipped when None or when `sparsity` is 0.
    X_test, y_test : array-like, optional
        Test inputs and targets. They must be passed together. When both are
        omitted, the abalone test split is loaded once via
        `load_abalone_regression_data(standardized=False, frac=1.0)`.

    Returns
    -------
    df_rmse : pandas.DataFrame
        One row per (model, draw) with columns model, sparsity, rmse.
    df_posterior_rmse : pandas.DataFrame
        One row per model with columns model, sparsity, posterior_mean_rmse.

    Raises
    ------
    ValueError
        If exactly one of `X_test` / `y_test` is supplied.
    """
    if (X_test is None) != (y_test is None):
        raise ValueError("X_test and y_test must be provided together")
    if X_test is None:
        # Loop-invariant: load the held-out split once, not once per model.
        _, X_test, _, y_test = load_abalone_regression_data(standardized=False, frac=1.0)

    # Flatten once so an (n, 1) target cannot broadcast against (n,) predictions.
    y_true = np.asarray(y_test).reshape(-1)

    apply_pruning = prune_fn is not None and sparsity > 0.0

    results = []
    posterior_means = []
    for model in models:
        try:
            fit = all_fits[model]['posterior']
            W1_samples = fit.stan_variable("W_1")           # presumably (S, P, H)
            W2_samples = fit.stan_variable("W_L")           # presumably (S, H, O)
            b1_samples = fit.stan_variable("hidden_bias")   # presumably (S, O, H)
            b2_samples = fit.stan_variable("output_bias")   # presumably (S, O)
        except KeyError:
            print(f"[SKIP] Model or posterior not found: {model}")
            continue

        n_draws = W1_samples.shape[0]
        y_hats = np.zeros((n_draws, y_true.shape[0]))

        for i in range(n_draws):
            W1 = W1_samples[i]
            W2 = W2_samples[i]

            if apply_pruning:
                masks = prune_fn([W1, W2], sparsity)
                W1 = W1 * masks[0]  # only the first layer is pruned

            y_hat = forward_pass(X_test, W1, b1_samples[i][0], W2, b2_samples[i])
            y_hats[i] = np.squeeze(y_hat)

        # RMSE of every individual draw, and of the posterior-mean prediction.
        rmses = np.sqrt(np.mean((y_hats - y_true[None, :]) ** 2, axis=1))
        posterior_mean = y_hats.mean(axis=0)
        posterior_mean_rmse = np.sqrt(np.mean((posterior_mean - y_true) ** 2))

        posterior_means.append({
            'model': model,
            'sparsity': sparsity,
            'posterior_mean_rmse': posterior_mean_rmse
        })
        results.extend(
            {'model': model, 'sparsity': sparsity, 'rmse': rmse}
            for rmse in rmses
        )

    df_rmse = pd.DataFrame(results)
    df_posterior_rmse = pd.DataFrame(posterior_means)

    return df_rmse, df_posterior_rmse


In [21]:
from utils.sparsity import forward_pass_relu, forward_pass_tanh, local_prune_weights

# Pruning levels at which both activations are evaluated.
sparsity_levels = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95]

# Result stores keyed by sparsity level: per-draw RMSE frames and
# posterior-mean RMSE frames, one pair per activation.
df_rmse_relu, df_posterior_rmse_relu = {}, {}
df_rmse_tanh, df_posterior_rmse_tanh = {}, {}

# (model names, fits, forward pass, per-draw store, posterior-mean store)
activation_runs = [
    (model_names_relu, abalone_relu_fit, forward_pass_relu,
     df_rmse_relu, df_posterior_rmse_relu),
    (model_names_tanh, abalone_tanh_fit, forward_pass_tanh,
     df_rmse_tanh, df_posterior_rmse_tanh),
]

for sparsity in sparsity_levels:
    # ReLU first, then tanh, at every sparsity level.
    for names, fits, forward_fn, rmse_store, posterior_store in activation_runs:
        rmse_store[sparsity], posterior_store[sparsity] = compute_sparse_rmse_results_abalone(
            models=names,
            all_fits=fits,
            forward_pass=forward_fn,
            sparsity=sparsity,
            prune_fn=local_prune_weights,
        )


In [None]:
# Stack the per-sparsity frames into one long frame per activation and strip
# the " tanh" suffix so both activations share the same prior names.
def _strip_tanh(frame):
    return frame["model"].str.replace(" tanh", "", regex=False)

df_rmse_full_relu = pd.concat(
    [df.assign(sparsity=sparsity) for sparsity, df in df_rmse_relu.items()],
    ignore_index=True,
).assign(model=_strip_tanh)

df_rmse_full_tanh = pd.concat(
    [df.assign(sparsity=sparsity) for sparsity, df in df_rmse_tanh.items()],
    ignore_index=True,
).assign(model=_strip_tanh)

# Plot (simplified version)
import matplotlib.pyplot as plt
import seaborn as sns

# Line colour per prior, and the short labels shown in the legend.
custom_palette = {
    "Dirichlet Horseshoe": "C2",
    "Dirichlet Student T": "C3",
    "Beta Horseshoe": "C4",
    "Beta Student T": "C5",
}
abbr = {
    "Dirichlet Horseshoe": "DHS",
    "Dirichlet Student T": "DST",
    "Beta Horseshoe": "BHS",
    "Beta Student T": "BST",
}
# Position of each prior in the palette; priors not listed sort last.
palette_rank = {prior: rank for rank, prior in enumerate(custom_palette)}

fig, axes = plt.subplots(1, 2, figsize=(14, 5), sharex=True, sharey=True)
activation_data = [("ReLU", df_rmse_full_relu), ("tanh", df_rmse_full_tanh)]

for ax, (name, df) in zip(axes, activation_data):
    df["model_abbr"] = df["model"].map(lambda m: abbr.get(m, m))

    models_present = df["model"].unique()
    panel_palette = {
        abbr[prior]: colour
        for prior, colour in custom_palette.items()
        if prior in models_present
    }
    panel_hue_order = [
        abbr[prior]
        for prior in sorted(models_present, key=lambda m: palette_rank.get(m, 999))
    ]

    sns.lineplot(
        data=df,
        x='sparsity',
        y='rmse',
        hue='model_abbr',
        marker='o',
        errorbar=None,
        ax=ax,
        palette=panel_palette,
        hue_order=panel_hue_order,
    )

    ax.set_title(name)
    ax.set_xlabel("Sparsity level")
    ax.set_ylabel("RMSE")
    ax.grid(True)
    ax.legend(title="Model", loc="upper left")

plt.tight_layout()
plt.show()
