In [None]:
import sys, os; sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__) if '__file__' in globals() else os.getcwd(), '..')))
#import os; os.chdir(os.path.dirname(os.getcwd()))
from utils.model_loader import get_model_fits
import numpy as np
import pandas as pd
import re
from sklearn.metrics import mean_squared_error
import seaborn as sns
import matplotlib.pyplot as plt


In [None]:
data_dir = "datasets/friedman"
results_dir = "results/regression/single_layer/tanh/friedman"
model_names = ["Dirichlet Horseshoe tanh", "Dirichlet Student T tanh", "Beta Horseshoe tanh", "Beta Student T tanh"]

# One entry per dataset file. The file stem (e.g. "Friedman_N100_p10_sigma1.00_seed1")
# is also the config name under which get_model_fits looks up the saved fits.
files = sorted(name for name in os.listdir(data_dir) if name.endswith(".npz"))
fits = {}
for fname in files:
    base_config_name = fname.replace(".npz", "")
    fits[base_config_name] = get_model_fits(
        config=base_config_name,
        results_dir=results_dir,
        models=model_names,
        include_prior=False,
    )

In [None]:
data_dir_correlated = "datasets/friedman_correlated"
results_dir_correlated = "results/regression/single_layer/tanh/friedman_correlated"
model_names_correlated = ["Dirichlet Horseshoe tanh", "Dirichlet Student T tanh", "Beta Horseshoe tanh", "Beta Student T tanh"]

# Same loading scheme as the original setting: the dataset file stem is the config name.
files = sorted(name for name in os.listdir(data_dir_correlated) if name.endswith(".npz"))
fits_correlated = {}
for fname in files:
    base_config_name = fname.replace(".npz", "")
    fits_correlated[base_config_name] = get_model_fits(
        config=base_config_name,
        results_dir=results_dir_correlated,
        models=model_names_correlated,
        include_prior=False,
    )
    


In [4]:
import re
import numpy as np
import pandas as pd
from properscoring import crps_ensemble

_FRIEDMAN_KEY = re.compile(r"Friedman_N(\d+)_p\d+_sigma([\d.]+)_seed(\d+)")

def extract_friedman_metadata(key: str):
    """
    Pull sample size, noise level and seed out of a Friedman dataset key.

    The key is searched (not anchored) for 'Friedman_N{N}_p{p}_sigma{sigma}_seed{seed}',
    so prefixes such as directories or suffixes such as '.npz' are tolerated.

    Returns:
        (N: int, sigma: float, seed: int), or (None, None, None) when no match is found.
    """
    match = _FRIEDMAN_KEY.search(key)
    if match is None:
        return None, None, None
    n_text, sigma_text, seed_text = match.groups()
    return int(n_text), float(sigma_text), int(seed_text)


In [5]:
def compute_rmse_from_fits(all_fits, model_names=None, folder="friedman"):
    """
    Compute per-draw and posterior-mean test RMSE for every Friedman dataset in `all_fits`.

    Args:
        all_fits: dict mapping dataset keys (e.g. "Friedman_N100_p10_sigma1.00_seed1")
            to dicts of {model_name: {"posterior": fit, ...}}. Keys that do not parse
            as Friedman datasets are skipped.
        model_names: models to evaluate; defaults to every model present for a dataset.
        folder: sub-directory of `datasets/` holding the .npz files. If a file is not
            found there, `datasets/{folder}/many/` is tried as a fallback.

    Returns:
        df_rmse: long DataFrame, one row per posterior draw (column "rmse").
        df_posterior_rmse: one row per dataset/model (column "posterior_mean_rmse").

    Raises:
        FileNotFoundError: if a dataset file exists in neither location.
        ValueError: if a fit's `output_test` is not (S, N_test) or (S, N_test, 1).
    """
    meta_cols = ["dataset_key", "model", "N", "sigma", "seed"]
    rmse_rows = []
    post_mean_rows = []

    for dataset_key, model_dict in all_fits.items():
        N, sigma, seed = extract_friedman_metadata(dataset_key)
        if N is None:
            # Skip non-Friedman entries if any
            continue

        fname = f"Friedman_N{N}_p10_sigma{sigma:.2f}_seed{seed}.npz"
        try:
            npz = np.load(f"datasets/{folder}/{fname}")
        except FileNotFoundError:
            npz = np.load(f"datasets/{folder}/many/{fname}")
        # NpzFile keeps the archive open; close it as soon as y_test is read.
        with npz as data:
            y_test = data["y_test"].squeeze()  # (N_test,)

        models_to_eval = model_names or list(model_dict.keys())

        for model in models_to_eval:
            # Some entries may be missing
            entry = model_dict.get(model, None)
            if not entry or "posterior" not in entry:
                print(f"[SKIP] Missing posterior: {dataset_key} -> {model}")
                continue

            fit = entry["posterior"]

            # Expecting (S, N_test, 1) or (S, N_test)
            output_test = fit.stan_variable("output_test")
            if output_test.ndim == 3 and output_test.shape[-1] == 1:
                preds = output_test[..., 0]  # (S, N_test)
            elif output_test.ndim == 2:
                preds = output_test  # (S, N_test)
            else:
                raise ValueError(f"Unexpected output_test shape {output_test.shape} for {dataset_key} -> {model}")

            meta = {"dataset_key": dataset_key, "model": model, "N": N, "sigma": sigma, "seed": seed}

            # Per-draw RMSE
            rmse_per_sample = np.sqrt(np.mean((preds - y_test[None, :]) ** 2, axis=1))  # (S,)
            rmse_rows.extend(
                {**meta, "sample_idx": s_idx, "rmse": float(rmse)}
                for s_idx, rmse in enumerate(rmse_per_sample)
            )

            # RMSE of the posterior-mean predictor
            posterior_mean = preds.mean(axis=0)  # (N_test,)
            post_mean_rmse = float(np.sqrt(np.mean((posterior_mean - y_test) ** 2)))
            post_mean_rows.append({**meta, "posterior_mean_rmse": post_mean_rmse})

    # Explicit columns keep the schema even when nothing was evaluated, so
    # downstream groupby calls fail with a clear empty result instead of KeyError.
    df_rmse = pd.DataFrame(rmse_rows, columns=meta_cols + ["sample_idx", "rmse"])
    df_posterior_rmse = pd.DataFrame(post_mean_rows, columns=meta_cols + ["posterior_mean_rmse"])
    return df_rmse, df_posterior_rmse


def compute_crps_from_fits(all_fits, model_names=None, folder="friedman"):
    """
    Compute mean CRPS per dataset/model from all posterior predictive draws.

    Args:
        all_fits: dict mapping Friedman dataset keys to {model_name: {"posterior": fit}}.
            Keys that do not parse as Friedman datasets are skipped.
        model_names: models to evaluate; defaults to every model present for a dataset.
        folder: sub-directory of `datasets/` holding the .npz files (default "friedman",
            which matches the previous hard-coded path). `datasets/{folder}/many/` is tried
            as a fallback, as in compute_rmse_from_fits. Datasets found in neither place
            are reported and skipped.

    Returns:
        df_crps: one row per dataset/model with the mean CRPS over test points.

    Raises:
        ValueError: if a fit's `output_test` is not (S, N_test) or (S, N_test, 1).
    """
    rows = []

    for dataset_key, model_dict in all_fits.items():
        N, sigma, seed = extract_friedman_metadata(dataset_key)
        if N is None:
            continue

        fname = f"Friedman_N{N}_p10_sigma{sigma:.2f}_seed{seed}.npz"
        path = f"datasets/{folder}/{fname}"
        y_test = None
        for candidate in (path, f"datasets/{folder}/many/{fname}"):
            try:
                # Context manager closes the NpzFile handle after reading.
                with np.load(candidate) as data:
                    y_test = data["y_test"].squeeze()  # (N_test,)
                break
            except FileNotFoundError:
                continue
        if y_test is None:
            print(f"[SKIP] y_test not found: {path}")
            continue

        models_to_eval = model_names or list(model_dict.keys())

        for model in models_to_eval:
            entry = model_dict.get(model, None)
            if not entry or "posterior" not in entry:
                print(f"[SKIP] Missing posterior: {dataset_key} -> {model}")
                continue

            fit = entry["posterior"]
            output_test = fit.stan_variable("output_test")

            # Expecting (S, N_test, 1) or (S, N_test)
            if output_test.ndim == 3 and output_test.shape[-1] == 1:
                preds = output_test[..., 0]  # (S, N_test)
            elif output_test.ndim == 2:
                preds = output_test  # (S, N_test)
            else:
                raise ValueError(f"Unexpected output_test shape {output_test.shape} for {dataset_key} -> {model}")

            # crps_ensemble expects the ensemble on the last axis: (N_test, S)
            crps_point = crps_ensemble(y_test, preds.T)  # (N_test,)
            rows.append({
                "dataset_key": dataset_key,
                "model": model,
                "N": N,
                "sigma": sigma,
                "seed": seed,
                "crps": float(crps_point.mean())
            })

    df_crps = pd.DataFrame(rows, columns=["dataset_key", "model", "N", "sigma", "seed", "crps"])
    return df_crps


In [None]:
# Evaluate models: per-draw and posterior-mean test RMSE for both data settings.
df_rmse, df_posterior_rmse = compute_rmse_from_fits(
    fits, model_names
)

# `folder` must match the on-disk directory (data_dir_correlated is
# "datasets/friedman_correlated"); paths are case-sensitive on most filesystems.
df_rmse_correlated, df_posterior_rmse_correlated = compute_rmse_from_fits(
    fits_correlated, model_names_correlated, folder="friedman_correlated"
)



In [None]:
# Mean and std of the per-draw RMSE for each (model, N). The "acc_*" column names
# are kept as they were for the LaTeX tables, although the values are RMSEs.
rmse_aggregations = {"acc_mean": ("rmse", "mean"), "acc_std": ("rmse", "std")}

summary = df_rmse.groupby(["model", "N"]).agg(**rmse_aggregations).reset_index()
summary_correlated = df_rmse_correlated.groupby(["model", "N"]).agg(**rmse_aggregations).reset_index()

for latex_table in (summary, summary_correlated):
    print(latex_table.to_latex(index=False, float_format="%.3f"))
# print(summary_tanh_correlated.to_latex(index=False, float_format="%.3f"))


In [7]:
import pandas as pd

# Tag every table with its data setting (all runs here use the "local" variant)
# so that both settings can be stacked and compared in one figure.
tag_original = {"activation": "local", "setting": "Original"}
tag_correlated = {"activation": "local", "setting": "Correlated"}

df1, df2 = df_rmse.assign(**tag_original), df_rmse_correlated.assign(**tag_correlated)
dfpm, dfpm_correlated = (
    df_posterior_rmse.assign(**tag_original),
    df_posterior_rmse_correlated.assign(**tag_correlated),
)

df_all = pd.concat([df1, df2], ignore_index=True)
df_pm_all = pd.concat([dfpm, dfpm_correlated], ignore_index=True)


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# --- prepare data ---
df = df_all.copy()

abbr = {
    "Dirichlet Horseshoe": "DHS",
    "Dirichlet Student T": "DST",
    "Beta Horseshoe": "BHS",
    "Beta Student T": "BST",
}

# unify model names across activations (strip " tanh")
df["model_clean"] = df["model"].str.replace(" tanh", "", regex=False)
# mean / std of the per-draw RMSE for each (setting, N, model, activation)
summary = (
    df.groupby(["setting", "N", "model_clean", "activation"], as_index=False)["rmse"]
      .agg(mean="mean", std="std")
)

# plotting order
settings = ["Original", "Correlated"]
Ns = [50, 100, 200, 500]
models = ["Dirichlet Horseshoe", "Dirichlet Student T", "Beta Horseshoe", "Beta Student T"]

# visuals
markers = {"local": "o", "nodewise": "X"}            # marker shape per activation variant
offsets = {"local": -0.05, "nodewise": +0.05}        # side-by-side jitter on x
model_offsets = {
    "Dirichlet Horseshoe": -0.06,
    "Beta Horseshoe": -0.02,
    "Dirichlet Student T": +0.02,
    "Beta Student T": +0.06,
}
palette_list = plt.get_cmap("tab10").colors
palette = {m: palette_list[i+1] for i, m in enumerate(models)}  # tab10 colours 1..4

# map N to base x positions; the offsets above separate activation and model
xbase = {N: i for i, N in enumerate(Ns)}

fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharey=True)

for ax, setting in zip(axes, settings):
    # Only N values with a tick position can be placed (xbase lookup below).
    sub = summary[(summary["setting"] == setting) & summary["N"].isin(Ns)]
    # plot each model+activation as mean ± std, without connecting lines
    for m in models:
        for act in ["local", "nodewise"]:
            g = sub[(sub["model_clean"] == m) & (sub["activation"] == act)]
            if g.empty:
                continue
            xs = [xbase[n] + offsets[act] + model_offsets[m] for n in g["N"]]

            ax.errorbar(
                xs, g["mean"], yerr=g["std"],
                fmt=markers[act], markersize=7,
                linestyle="none", capsize=3,
                color=palette[m], markeredgecolor="black"
            )

    ax.set_title(f"{setting}")
    ax.set_xticks(range(len(Ns)))
    ax.set_xticklabels(Ns)
    ax.set_xlabel("N")
    ax.set_ylabel("RMSE")
    ax.grid()

# --- legends ---
model_handles = [
    Line2D(
        [0], [0],
        marker="o",
        linestyle="none",
        color=palette[m],
        markeredgecolor="black",
        markersize=7,
        label=abbr.get(m, m)   # abbreviated prior name
    )
    for m in models
]

# Shape legend only for activation variants present in the data. When df_all is built,
# every row is tagged "local", so an unconditional "Nodewise" entry would describe
# markers that are never drawn.
activation_labels = {"local": "Local", "nodewise": "Nodewise"}
present_activations = set(summary["activation"])
activation_handles = [
    Line2D([0], [0], marker=markers[act], linestyle="none", color="black",
           markersize=7, label=label)
    for act, label in activation_labels.items()
    if act in present_activations
]

for ax in axes:
    ax.legend(
        handles=model_handles + activation_handles,
        title=None,
        loc="upper right",
        frameon=False,
        ncol=1
    )
plt.tight_layout(rect=(0, 0, 1, 1))
plt.show()

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# --- posterior mean RMSE df (rename for clarity) ---
df_pm = df_pm_all.copy()

# make model names consistent with the plotting labels (strip " tanh")
df_pm["model_clean"] = df_pm["model"].str.replace(" tanh", "", regex=False)

# average over seeds per (setting, N, model).
# NOTE(review): pm_summary is not used by the figure below; it is kept for inspection.
pm_summary = (
    df_pm.groupby(["setting", "N", "model_clean"], as_index=False)["posterior_mean_rmse"]
        .mean()
)

abbr = {
    "Dirichlet Horseshoe": "DHS",
    "Dirichlet Student T": "DST",
    "Beta Horseshoe": "BHS",
    "Beta Student T": "BST",
}

# summary stats per (setting, N, model, activation); only the mean is plotted
summary = (
    df_pm.groupby(["setting", "N", "model_clean", "activation"], as_index=False)["posterior_mean_rmse"]
      .agg(mean="mean", std="std")
)

# plotting order
settings = ["Original", "Correlated"]
Ns = [50, 100, 200, 500]
models = ["Dirichlet Horseshoe", "Dirichlet Student T", "Beta Horseshoe", "Beta Student T"]

# visuals
markers = {"local": "o", "nodewise": "X"}            # marker shape per activation variant
offsets = {"local": -0.05, "nodewise": +0.05}        # side-by-side jitter on x
# Same per-model x offsets as the per-draw RMSE figure, so each prior sits at the
# same horizontal position in both plots.
model_offsets = {
    "Dirichlet Horseshoe": -0.06,
    "Beta Horseshoe": -0.02,
    "Dirichlet Student T": +0.02,
    "Beta Student T": +0.06,
}
palette_list = plt.get_cmap("tab10").colors
palette = {m: palette_list[i+1] for i, m in enumerate(models)}  # tab10 colours 1..4

# map N to base x positions; the offsets above separate activation and model
xbase = {N: i for i, N in enumerate(Ns)}

fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharey=True)

for ax, setting in zip(axes, settings):
    # Only N values with a tick position can be placed (xbase lookup below).
    sub = summary[(summary["setting"] == setting) & summary["N"].isin(Ns)]
    # one marker per model+activation, without connecting lines
    for m in models:
        for act in ["local", "nodewise"]:
            g = sub[(sub["model_clean"] == m) & (sub["activation"] == act)]
            if g.empty:
                continue
            xs = [xbase[n] + offsets[act] + model_offsets[m] for n in g["N"]]
            ax.plot(
                xs, g["mean"],
                marker=markers[act],
                markersize=7,
                linestyle="none",
                color=palette[m],
                markeredgecolor="black",
            )

    ax.set_title(f"{setting}")
    ax.set_xticks(range(len(Ns)))
    ax.set_xticklabels(Ns)
    ax.set_xlabel("N")
    ax.set_ylabel("RMSE")
    ax.grid()

# --- legends ---
model_handles = [
    Line2D(
        [0], [0],
        marker="o",
        linestyle="none",
        color=palette[m],
        markeredgecolor="black",
        markersize=7,
        label=abbr.get(m, m)   # abbreviated prior name
    )
    for m in models
]

# marker legend: every point in this figure is a posterior-mean RMSE
activation_handles = [
    Line2D([0], [0], marker=markers["local"], linestyle="none", color="black",
           markersize=7, label="Posterior RMSE"),
]

for ax in axes:
    ax.legend(
        handles=model_handles + activation_handles,
        title=None,
        loc="upper right",
        frameon=False,
        ncol=1
    )
plt.tight_layout(rect=(0, 0, 1, 1))
plt.show()

In [6]:
def compute_posterior_mean_rmse_aggregated(
    all_fits,
    model_names=None,
    folder="friedman",
    agg_fun="mean",
):
    """
    Compute posterior-mean RMSE per dataset, then aggregate over datasets
    of the same type (same model, N, sigma).

    Args:
        all_fits: dict mapping Friedman dataset keys to {model_name: {"posterior": fit}}.
            Keys that do not parse as Friedman datasets are skipped.
        model_names: models to evaluate; defaults to every model present for a dataset.
        folder: sub-directory of `datasets/`; `datasets/{folder}/many/` is the fallback.
        agg_fun: "mean" (columns rmse_mean, rmse_sd, n_datasets) or
            "median" (columns rmse_median, rmse_iqr, n_datasets).

    Returns:
        df_per_dataset: one row per (dataset_key, model)
        df_aggregated:  one row per (model, N, sigma); empty, with the same columns,
            when no dataset/model was evaluated.

    Raises:
        ValueError: if `agg_fun` is not "mean"/"median" (checked before any I/O), or
            a fit's `output_test` has an unexpected shape.
        FileNotFoundError: if a dataset file exists in neither location.
    """
    # Validate up front so a typo does not cost a full pass over every fit.
    if agg_fun not in ("mean", "median"):
        raise ValueError("agg_fun must be 'mean' or 'median'")

    post_mean_rows = []

    for dataset_key, model_dict in all_fits.items():
        N, sigma, seed = extract_friedman_metadata(dataset_key)
        if N is None:
            continue

        fname = f"Friedman_N{N}_p10_sigma{sigma:.2f}_seed{seed}.npz"
        try:
            npz = np.load(f"datasets/{folder}/{fname}")
        except FileNotFoundError:
            npz = np.load(f"datasets/{folder}/many/{fname}")
        # Close the NpzFile handle once y_test has been read into memory.
        with npz as data:
            y_test = data["y_test"].squeeze()

        models_to_eval = model_names or list(model_dict.keys())

        for model in models_to_eval:
            entry = model_dict.get(model)
            if not entry or "posterior" not in entry:
                continue

            fit = entry["posterior"]
            output_test = fit.stan_variable("output_test")

            if output_test.ndim == 3 and output_test.shape[-1] == 1:
                preds = output_test[..., 0]
            elif output_test.ndim == 2:
                preds = output_test
            else:
                raise ValueError(f"Unexpected output_test shape {output_test.shape}")

            posterior_mean = preds.mean(axis=0)
            rmse = float(np.sqrt(np.mean((posterior_mean - y_test) ** 2)))

            post_mean_rows.append({
                "dataset_key": dataset_key,
                "model": model,
                "N": N,
                "sigma": sigma,
                "seed": seed,
                "posterior_mean_rmse": rmse,
            })

    df_per_dataset = pd.DataFrame(
        post_mean_rows,
        columns=["dataset_key", "model", "N", "sigma", "seed", "posterior_mean_rmse"],
    )

    # ---- aggregation over datasets of same type ----
    group_cols = ["model", "N", "sigma"]
    if agg_fun == "mean":
        stat_cols = ["rmse_mean", "rmse_sd", "n_datasets"]
    else:
        stat_cols = ["rmse_median", "rmse_iqr", "n_datasets"]

    if df_per_dataset.empty:
        # Nothing to aggregate: return the expected schema instead of failing in groupby.
        return df_per_dataset, pd.DataFrame(columns=group_cols + stat_cols)

    if agg_fun == "mean":
        df_agg = (
            df_per_dataset
            .groupby(group_cols)
            .agg(
                rmse_mean=("posterior_mean_rmse", "mean"),
                rmse_sd=("posterior_mean_rmse", "std"),
                n_datasets=("posterior_mean_rmse", "size"),
            )
            .reset_index()
        )
    else:
        df_agg = (
            df_per_dataset
            .groupby(group_cols)
            .agg(
                rmse_median=("posterior_mean_rmse", "median"),
                rmse_iqr=(
                    "posterior_mean_rmse",
                    lambda x: np.percentile(x, 75) - np.percentile(x, 25),
                ),
                n_datasets=("posterior_mean_rmse", "size"),
            )
            .reset_index()
        )

    return df_per_dataset, df_agg


In [None]:
# Posterior-mean RMSE per dataset, then mean/std across datasets sharing (model, N, sigma).
df_post_rmse, df_agg_post_rmse = compute_posterior_mean_rmse_aggregated(
    all_fits=fits,
    model_names=model_names,
)
df_agg_post_rmse

## SPARSITY

In [14]:
from utils.sparsity import forward_pass_relu, forward_pass_tanh, local_prune_weights

def compute_sparse_rmse_results(seeds, models, all_fits, get_N_sigma, forward_pass, folder,
                         sparsity=0.0, prune_fn=None):
    """
    Compute test RMSE of every posterior draw, optionally after pruning first-layer weights.

    For each seed, N and sigma come from `get_N_sigma(seed)`, and the test split is read from
    `datasets/{folder}/Friedman_N{N}_p10_sigma{sigma:.2f}_seed{seed}.npz`. For each model,
    the posterior draws of W_1, hidden_bias, W_L and output_bias are pushed through
    `forward_pass(X_test, W1, b1, W2, b2)`. If `prune_fn` is given and `sparsity > 0`,
    `prune_fn([W1, W2], sparsity)` is called per draw and only the first mask (for W1) is applied.

    Returns:
        df_rmse: one row per (seed, model, draw) with that draw's RMSE.
        df_posterior_rmse: one row per (seed, model) with the RMSE of the
            posterior-mean prediction.

    Missing dataset files and missing model/posterior entries are reported and skipped.
    """
    results = []
    posterior_means = []

    for seed in seeds:
        N, sigma = get_N_sigma(seed)
        dataset_key = f'Friedman_N{N}_p10_sigma{sigma:.2f}_seed{seed}'
        path = f"datasets/{folder}/{dataset_key}.npz"

        try:
            # Context manager closes the NpzFile handle after reading.
            with np.load(path) as data:
                X_test = data["X_test"]
                # Flatten to (N_test,). An (N_test, 1) target would otherwise broadcast
                # against the (N_test,) predictions into an (N_test, N_test) error matrix.
                y_test = np.atleast_1d(data["y_test"].squeeze())
        except FileNotFoundError:
            print(f"[SKIP] File not found: {path}")
            continue

        for model in models:
            try:
                fit = all_fits[dataset_key][model]['posterior']
                W1_samples = fit.stan_variable("W_1")           # (S, P, H)
                W2_samples = fit.stan_variable("W_L")           # (S, H, O)
                b1_samples = fit.stan_variable("hidden_bias")   # (S, O, H)
                b2_samples = fit.stan_variable("output_bias")   # (S, O)
            except KeyError:
                print(f"[SKIP] Model or posterior not found: {dataset_key} -> {model}")
                continue

            S = W1_samples.shape[0]
            y_hats = np.zeros((S, y_test.shape[0]))

            for i in range(S):
                W1 = W1_samples[i]
                W2 = W2_samples[i]

                # Apply pruning mask if requested; only input-layer weights are pruned.
                if prune_fn is not None and sparsity > 0.0:
                    masks = prune_fn([W1, W2], sparsity)
                    W1 = W1 * masks[0]

                y_hat = forward_pass(X_test, W1, b1_samples[i][0], W2, b2_samples[i])
                y_hats[i] = np.ravel(y_hat)  # prediction of draw i, (N_test,)

            rmses = np.sqrt(np.mean((y_hats - y_test[None, :]) ** 2, axis=1))  # (S,)
            posterior_mean = y_hats.mean(axis=0)
            posterior_mean_rmse = np.sqrt(np.mean((posterior_mean - y_test) ** 2))

            posterior_means.append({
                'seed': seed,
                'N': N,
                'sigma': sigma,
                'model': model,
                'sparsity': sparsity,
                'posterior_mean_rmse': posterior_mean_rmse
            })

            results.extend(
                {
                    'seed': seed,
                    'N': N,
                    'sigma': sigma,
                    'model': model,
                    'sparsity': sparsity,
                    'rmse': rmse
                }
                for rmse in rmses
            )

    df_rmse = pd.DataFrame(results)
    df_posterior_rmse = pd.DataFrame(posterior_means)

    return df_rmse, df_posterior_rmse


# Pruning levels passed as `sparsity` to compute_sparse_rmse_results: 0.0, 0.1, ..., 0.8.
sparsity_levels = [round(0.1 * k, 1) for k in range(9)]

# One seed per N in each setting (mapping in get_N_sigma / get_N_sigma_correlated).
seeds, seeds_correlated = [1, 2, 11, 16], [1, 6, 11, 16]

def get_N_sigma(seed):
    """Return (N, sigma) for a seed of the original Friedman runs; unlisted seeds map to N=500."""
    sample_size = {1: 100, 2: 200, 16: 50}.get(seed, 500)
    return sample_size, 1.00

def get_N_sigma_correlated(seed):
    """Return (N, sigma) for a seed of the correlated Friedman runs; unlisted seeds map to N=500."""
    sample_size = {1: 100, 6: 200, 16: 50}.get(seed, 500)
    return sample_size, 1.00

In [15]:
# Per-draw and posterior-mean RMSE at every pruning level, keyed by sparsity.
df_rmse_sparse, df_posterior_rmse_sparse = {}, {}
df_rmse_sparse_correlated, df_posterior_rmse_sparse_correlated = {}, {}

original_args = (seeds, model_names, fits, get_N_sigma, forward_pass_tanh)
correlated_args = (seeds_correlated, model_names_correlated, fits_correlated,
                   get_N_sigma_correlated, forward_pass_tanh)

for sparsity in sparsity_levels:
    (df_rmse_sparse[sparsity],
     df_posterior_rmse_sparse[sparsity]) = compute_sparse_rmse_results(
        *original_args, folder="friedman",
        sparsity=sparsity, prune_fn=local_prune_weights,
    )
    (df_rmse_sparse_correlated[sparsity],
     df_posterior_rmse_sparse_correlated[sparsity]) = compute_sparse_rmse_results(
        *correlated_args, folder="friedman_correlated",
        sparsity=sparsity, prune_fn=local_prune_weights,
    )
    

In [16]:
import pandas as pd

# Stack the per-sparsity posterior-mean RMSE tables into one long frame per setting.
# (Despite the "rmse_full" name, these hold posterior-mean RMSEs, not per-draw rows.)
df_rmse_full = pd.concat(
    [frame.assign(sparsity=level) for level, frame in df_posterior_rmse_sparse.items()],
    ignore_index=True,
)
df_rmse_full_correlated = pd.concat(
    [frame.assign(sparsity=level) for level, frame in df_posterior_rmse_sparse_correlated.items()],
    ignore_index=True,
)

# Drop the " tanh" suffix so model names match the prior labels used in the plots.
df_local_o = df_rmse_full.assign(
    model=df_rmse_full["model"].str.replace(" tanh", "", regex=False)
)
df_local_c = df_rmse_full_correlated.assign(
    model=df_rmse_full_correlated["model"].str.replace(" tanh", "", regex=False)
)

In [17]:
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.lines import Line2D
import pandas as pd
from collections import OrderedDict

# Colour per prior. The order is spelled out here instead of reusing the `models`
# global from an earlier plotting cell (hidden state), and matches that list, so the
# colours agree with the earlier figures.
prior_order = ["Dirichlet Horseshoe", "Dirichlet Student T", "Beta Horseshoe", "Beta Student T"]
palette_list = plt.get_cmap("tab10").colors
palette = {m: palette_list[i + 1] for i, m in enumerate(prior_order)}

def make_merged_df(
    df_local_o, df_local_c,
    drop_tanh_suffix=True
):
    """
    Stack the original and correlated result tables into one long DataFrame.

    Each input is copied (never mutated), optionally stripped of a " tanh" model suffix,
    and tagged with activation="Tanh" plus setting="Original"/"Correlated", in that order.
    Rows are then restricted to models that appear under the "Tanh" activation; with a
    single activation this keeps every model.
    """
    tagged = []
    for setting, frame in (("Original", df_local_o), ("Correlated", df_local_c)):
        part = frame.copy()
        joined_names = "".join(part["model"].unique())
        if drop_tanh_suffix and " tanh" in joined_names:
            part["model"] = part["model"].str.replace(" tanh", "", regex=False)
        part["activation"] = "Tanh"
        part["setting"] = setting
        tagged.append(part)
    merged = pd.concat(tagged, ignore_index=True)

    tanh_models = set(merged.loc[merged["activation"] == "Tanh", "model"].unique())
    if not tanh_models:
        return merged
    return merged[merged["model"].isin(tanh_models)]

# NOTE: rebinds `df_all` (previously the per-draw RMSE table) to the sparsity results.
df_all = make_merged_df(df_local_o=df_local_o, df_local_c=df_local_c)


In [20]:
def plot_rmse_one_figure(
    df_all,
    Ns=(100, 200, 500),
    figsize=(12, 7),
    title="Original vs Correlated"
):
    """
    Plot posterior-mean RMSE against sparsity: one row per setting, one column per N.

    Args:
        df_all: long DataFrame with columns "N", "setting" ("Original"/"Correlated"),
            "model", "sparsity" and "posterior_mean_rmse" (as built by make_merged_df).
        Ns: sample sizes to show, one column each.
        figsize: matplotlib figure size.
        title: figure suptitle.

    Uses the module-level `abbr` (model -> short label) and `palette` (model -> colour).
    Subplots without data are hidden; a single "Prior" legend is drawn for the figure.
    Note: sets the seaborn style and updates plt.rcParams globally.
    """
    # Orderings
    setting_order = ["Original", "Correlated"]

    sns.set_style("whitegrid")
    plt.rcParams.update({
        "axes.spines.top": False,
        "axes.spines.right": False,
        "legend.frameon": True,
    })

    fig, axes = plt.subplots(2, len(Ns), figsize=figsize, sharex=True, sharey="col")
    if len(Ns) == 1:
        axes = axes.reshape(2, 1)

    for j, Nval in enumerate(Ns):
        for i, setting in enumerate(setting_order):
            ax = axes[i, j]
            dfN = df_all[(df_all["N"] == Nval) & (df_all["setting"] == setting)].copy()

            # Abbreviated labels for models
            dfN["model_abbr"] = dfN["model"].map(lambda m: abbr.get(m, m))

            # Colours keyed by the abbreviated label, in palette order. Models without a
            # palette entry are left out of hue_order (seaborn then skips them) instead of
            # raising on a missing abbreviation or colour.
            present = set(dfN["model"].unique())
            color_map = {abbr.get(m, m): palette[m] for m in palette if m in present}
            hue_order = list(color_map)

            # If there is nothing to draw, hide this subplot
            if dfN.empty or not color_map:
                ax.set_visible(False)
                continue

            # No `style` variable is used, so seaborn's `markers=`/`dashes=` would have no
            # effect. Draw a circle marker on every line directly, so the curves match
            # the legend handles below.
            sns.lineplot(
                data=dfN,
                x="sparsity",
                y="posterior_mean_rmse",
                hue="model_abbr",       # color = prior (abbr)
                hue_order=hue_order,
                palette=color_map,
                marker="o",
                errorbar=None,
                ax=ax,
            )

            ax.set_title(f"N={Nval}", fontweight="normal")
            ax.set_xlabel("Sparsity")
            ax.set_ylabel("RMSE" if j == 0 else "")
            ax.grid(True, which="major", alpha=0.25)

            # Remove per-axes legends; one global legend is added below
            if ax.legend_:
                ax.legend_.remove()

    # ---------- Global legend for priors (colors) ----------
    models_present = [
        m for m in ["Dirichlet Horseshoe", "Dirichlet Student T",
                    "Beta Horseshoe", "Beta Student T"]
        if m in palette and (df_all["model"] == m).any()
    ]

    prior_handles = [
        Line2D(
            [0], [0],
            color=palette[m],
            marker="o",
            linestyle="-",
            linewidth=2,
            markersize=7
        )
        for m in models_present
    ]
    prior_labels = [abbr.get(m, m) for m in models_present]

    if prior_handles:
        fig.legend(
            prior_handles,
            prior_labels,
            title="Prior",
            loc="upper center",
            ncol=len(prior_handles),
            frameon=True,
            bbox_to_anchor=(0.5, 1.02),
        )

    fig.suptitle(title, y=1.05)
    plt.tight_layout(rect=[0.02, 0.02, 0.98, 0.96])
    plt.show()


In [None]:

# RMSE-vs-sparsity curves for N = 50, 100, 200 (the N = 500 seed is not shown here).
plot_rmse_one_figure(
    df_all,
    Ns=(50, 100, 200),
    title="Original vs Correlated",
)


## SAMPLING

In [22]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import dirichlet, beta, cauchy

# Seed NumPy's global RandomState: np.random.randn and the scipy.stats .rvs calls
# below (no random_state argument) draw from it.
np.random.seed(1)

# dimensions: P coefficients per draw, S prior draws
P, S = 10, 1_000

# hyperparameters
tau_scale = 1.0                # scale of the half-Cauchy global scale tau
alpha_dir, alpha_beta = 0.1, 0.1   # Dirichlet / Beta concentration for xi

from scipy.stats import invgamma

def regularize_lambda(lambda_, tau, a=2.0, b=4.0, eps=1e-12):
    """
    Apply the regularized-horseshoe slab to local scales.

    For every draw s a slab variance c_s^2 ~ InvGamma(a, scale=b) is sampled, and each
    local scale becomes
        lambda_tilde = sqrt(max(c_s^2 * lambda^2 / (c_s^2 + tau_s^2 * lambda^2), eps)).

    Args:
        lambda_: 2-D array (S, P) of local scales.
        tau: 1-D array (S,) of global scales.
        a, b: inverse-gamma shape and scale for c^2.
        eps: floor applied to the squared scale before the square root.

    Returns:
        (S, P) array of regularized local scales.
    """
    n_draws, _ = lambda_.shape
    slab_var = invgamma.rvs(a=a, scale=b, size=n_draws)[:, None]  # (S, 1)
    lam_sq = lambda_ ** 2
    shrunk_sq = slab_var * lam_sq / (slab_var + tau[:, None] ** 2 * lam_sq)
    return np.sqrt(np.maximum(shrunk_sq, eps))

def sample_weights_rhs(P, S):
    """
    Draw S prior samples of a P-dimensional weight vector from the regularized horseshoe.

    w = tau * lambda_tilde * z, with tau ~ |Cauchy(0, tau_scale)| per draw,
    lambda ~ |Cauchy(0, 1)| per weight (regularized by regularize_lambda) and z ~ N(0, 1).
    Returns an (S, P) array. Random draws are taken in the order tau, lambda, c^2, z.
    """
    global_scale = np.abs(cauchy.rvs(scale=tau_scale, size=S))
    local_scale = regularize_lambda(
        np.abs(cauchy.rvs(scale=1.0, size=(S, P))), global_scale
    )
    noise = np.random.randn(S, P)
    return global_scale[:, None] * local_scale * noise

def sample_weights_dirichlet(P, S):
    """
    Draw S prior samples of P weights: regularized horseshoe times Dirichlet proportions.

    w = tau * lambda_tilde * sqrt(xi) * z with xi ~ Dirichlet(alpha_dir, ..., alpha_dir)
    per draw. Returns an (S, P) array. Random draws are taken in the order
    tau, lambda, c^2, xi, z.
    """
    global_scale = np.abs(cauchy.rvs(scale=tau_scale, size=S))
    local_scale = regularize_lambda(
        np.abs(cauchy.rvs(scale=1.0, size=(S, P))), global_scale
    )
    proportions = dirichlet.rvs([alpha_dir] * P, size=S)
    noise = np.random.randn(S, P)
    return global_scale[:, None] * local_scale * np.sqrt(proportions) * noise

def sample_weights_beta(P, S):
    """
    Draw S prior samples of P weights: regularized horseshoe times normalized Beta proportions.

    Each xi_j ~ Beta(alpha_beta, (P - 1) * alpha_beta) independently, renormalized to sum
    to 1 per draw, then w = tau * lambda_tilde * sqrt(xi) * z. Returns an (S, P) array.
    Random draws are taken in the order tau, lambda, c^2, xi, z.
    """
    global_scale = np.abs(cauchy.rvs(scale=tau_scale, size=S))
    local_scale = regularize_lambda(
        np.abs(cauchy.rvs(scale=1.0, size=(S, P))), global_scale
    )
    raw = beta.rvs(alpha_beta, (P - 1) * alpha_beta, size=(S, P))
    proportions = raw / raw.sum(axis=1, keepdims=True)
    noise = np.random.randn(S, P)
    return global_scale[:, None] * local_scale * np.sqrt(proportions) * noise


In [None]:
S = 50  # NOTE: rebinds the S = 1_000 set in the config cell
w_dir = sample_weights_dirichlet(P, S)
w_beta = sample_weights_beta(P, S)
w_rhs = sample_weights_rhs(P, S)
bins = S/5

# Compare the first-coordinate marginals of the Dirichlet- and Beta-based priors.
fig, ax = plt.subplots(figsize=(7, 5))
for samples, label in ((w_dir, "Dirichlet ξ"), (w_beta, "Beta ξ")):
    ax.hist(samples[:, 0], bins=int(bins), density=True, alpha=0.5, label=label)
ax.set_ylim(0, 3)
ax.legend()
ax.set_title("Marginal prior distribution of weights")
plt.show()


In [None]:
import numpy as np

def kl_js_from_hist(samples_p, samples_q, bins=200, range=(-1, 1), eps=1e-12):
    """
    Approximate KL(P||Q), KL(Q||P) and JS(P,Q) from two sample sets.

    Both sets are binned onto the same histogram over the fixed `range` (samples outside
    it are ignored) and normalized to discrete distributions. Each bin probability is
    floored at `eps` and renormalized, so that the logs and ratios stay finite.

    Args:
        samples_p, samples_q: 1-D sample arrays.
        bins: number of histogram bins.
        range: (low, high) histogram range; note that it shadows the builtin inside this function.
        eps: probability floor per bin.

    Returns:
        kl_pq, kl_qp, js  (natural-log units)

    Raises:
        ValueError: if either sample set has no values inside `range`. The histogram would
            then be all zeros and every divergence silently NaN.
    """
    p_counts, _ = np.histogram(samples_p, bins=bins, range=range)
    q_counts, _ = np.histogram(samples_q, bins=bins, range=range)

    if p_counts.sum() == 0 or q_counts.sum() == 0:
        raise ValueError(f"no samples fall inside range={range}; cannot compare histograms")

    def _to_probs(counts):
        # counts -> probabilities, floored at eps (avoids log(0) / division by zero), renormalized
        probs = counts.astype(float) / counts.sum()
        probs = np.clip(probs, eps, None)
        return probs / probs.sum()

    p = _to_probs(p_counts)
    q = _to_probs(q_counts)

    # KL divergences
    kl_pq = np.sum(p * np.log(p / q))
    kl_qp = np.sum(q * np.log(q / p))

    # Jensen–Shannon divergence via the mixture m = (p + q) / 2
    m = 0.5 * (p + q)
    js = 0.5 * np.sum(p * np.log(p / m)) + 0.5 * np.sum(q * np.log(q / m))

    return kl_pq, kl_qp, js

# Histogram-based divergences between the first-coordinate marginals of the three
# priors, restricted to [-1, 1] with 200 bins.
kl_dir_beta, kl_beta_dir, js = kl_js_from_hist(
    w_dir[:, 0], w_beta[:, 0],
    bins=200, range=(-1, 1), eps=1e-12
)

print(f"KL(Dir || Beta) over [-1,1]: {kl_dir_beta:.6g}")
print(f"KL(Beta || Dir) over [-1,1]: {kl_beta_dir:.6g}")
print(f"JS(Dir, Beta) over [-1,1]:   {js:.6g}")

kl_dir_rhs, kl_rhs_dir, js = kl_js_from_hist(
    w_dir[:, 0], w_rhs[:, 0],
    bins=200, range=(-1, 1), eps=1e-12
)

print(f"KL(Dir || RHS) over [-1,1]: {kl_dir_rhs:.6g}")
print(f"KL(RHS || Dir) over [-1,1]: {kl_rhs_dir:.6g}")
print(f"JS(Dir, RHS) over [-1,1]:   {js:.6g}")

# The second value here is KL(RHS || Beta). Give it its own name so that it no
# longer overwrites kl_beta_dir (KL(Beta || Dir)) from the first comparison.
kl_beta_rhs, kl_rhs_beta, js = kl_js_from_hist(
    w_beta[:, 0], w_rhs[:, 0],
    bins=200, range=(-1, 1), eps=1e-12
)

print(f"KL(Beta || RHS) over [-1,1]: {kl_beta_rhs:.6g}")
print(f"KL(RHS || Beta) over [-1,1]: {kl_rhs_beta:.6g}")
print(f"JS(Beta, RHS) over [-1,1]:   {js:.6g}")


## ABALONE

In [None]:
data_dir = "datasets/abalone"
results_dir_relu = "results/regression/single_layer/relu/abalone"
results_dir_tanh = "results/regression/single_layer/tanh/abalone"

# Same four priors for both activations; the tanh runs carry a " tanh" suffix.
model_names_relu = ["Dirichlet Horseshoe", "Dirichlet Student T", "Beta Horseshoe", "Beta Student T"]
model_names_tanh = [f"{name} tanh" for name in model_names_relu]

# One abalone configuration shared by both result directories (ReLU loaded first).
full_config_path = "abalone_N3341_p8"
abalone_relu_fit, abalone_tanh_fit = (
    get_model_fits(
        config=full_config_path,
        results_dir=rdir,
        models=names,
        include_prior=False,
    )
    for rdir, names in ((results_dir_relu, model_names_relu),
                        (results_dir_tanh, model_names_tanh))
)

In [None]:
from sklearn.metrics import mean_squared_error
from properscoring import crps_ensemble
import numpy as np
import pandas as pd

# Posterior-predictive scores (RMSE of the mean, mean per-draw RMSE, CRPS) for
# the ReLU abalone fits.
# IMPORTANT: this y_test must correspond to the same test set used to make `output_test` in Stan,
# otherwise scores won't be comparable.
from utils.generate_data import load_abalone_regression_data
X_train, X_test, y_train, y_test = load_abalone_regression_data(standardized=False, frac=1.0)

# Flatten the targets once. The residuals below are (S, n_test) - (1, n_test):
# an (n, 1) column would silently broadcast to the wrong shape, and recent
# pandas rejects `series[None, :]` outright.
y_true = np.asarray(y_test, dtype=float).reshape(-1)

rows = []
for model_name, model_entry in abalone_relu_fit.items():
    post = model_entry["posterior"]

    # Posterior-predictive draws squeezed to (S, n_test); assumes `output_test`
    # carries a trailing singleton output axis. TODO: confirm against the Stan model.
    y_samps = post.stan_variable("output_test").squeeze(-1)
    if y_samps.shape[-1] != y_true.shape[0]:
        raise ValueError(
            f"{model_name}: output_test has {y_samps.shape[-1]} test points "
            f"but y_test has {y_true.shape[0]}; the test sets do not match"
        )

    # RMSE of the posterior-mean prediction
    y_mean = y_samps.mean(axis=0)                                   # (n_test,)
    rmse_post_mean = float(np.sqrt(mean_squared_error(y_true, y_mean)))

    # RMSE of every individual draw, then averaged over draws
    per_draw_rmse = np.sqrt(((y_samps - y_true[None, :]) ** 2).mean(axis=1))  # (S,)
    rmse_draw_mean = float(per_draw_rmse.mean())

    # crps_ensemble expects the ensemble on the last axis: (n_test, S)
    crps = float(np.mean(crps_ensemble(y_true, y_samps.T)))

    rows.append({
        "Model": model_name,
        "RMSE_posterior_mean": rmse_post_mean,
        "RMSE_mean_over_draws": rmse_draw_mean,
        "CRPS": crps,
        "n_draws": y_samps.shape[0]
    })

results_df = pd.DataFrame(rows).sort_values("RMSE_posterior_mean")
print(results_df)


In [None]:
from sklearn.metrics import mean_squared_error
from properscoring import crps_ensemble
import numpy as np
import pandas as pd

# Posterior-predictive scores (RMSE of the mean, mean per-draw RMSE, CRPS) for
# the tanh abalone fits.
# IMPORTANT: this y_test must correspond to the same test set used to make `output_test` in Stan,
# otherwise scores won't be comparable.
from utils.generate_data import load_abalone_regression_data
X_train, X_test, y_train, y_test = load_abalone_regression_data(standardized=False, frac=1.0)

# Flatten the targets once. The residuals below are (S, n_test) - (1, n_test):
# an (n, 1) column would silently broadcast to the wrong shape, and recent
# pandas rejects `series[None, :]` outright.
y_true = np.asarray(y_test, dtype=float).reshape(-1)

rows = []
for model_name, model_entry in abalone_tanh_fit.items():
    post = model_entry["posterior"]

    # Posterior-predictive draws squeezed to (S, n_test); assumes `output_test`
    # carries a trailing singleton output axis. TODO: confirm against the Stan model.
    y_samps = post.stan_variable("output_test").squeeze(-1)
    if y_samps.shape[-1] != y_true.shape[0]:
        raise ValueError(
            f"{model_name}: output_test has {y_samps.shape[-1]} test points "
            f"but y_test has {y_true.shape[0]}; the test sets do not match"
        )

    # RMSE of the posterior-mean prediction
    y_mean = y_samps.mean(axis=0)                                   # (n_test,)
    rmse_post_mean = float(np.sqrt(mean_squared_error(y_true, y_mean)))

    # RMSE of every individual draw, then averaged over draws
    per_draw_rmse = np.sqrt(((y_samps - y_true[None, :]) ** 2).mean(axis=1))  # (S,)
    rmse_draw_mean = float(per_draw_rmse.mean())

    # crps_ensemble expects the ensemble on the last axis: (n_test, S)
    crps = float(np.mean(crps_ensemble(y_true, y_samps.T)))

    rows.append({
        "Model": model_name,
        "RMSE_posterior_mean": rmse_post_mean,
        "RMSE_mean_over_draws": rmse_draw_mean,
        "CRPS": crps,
        "n_draws": y_samps.shape[0]
    })

results_df = pd.DataFrame(rows).sort_values("RMSE_posterior_mean")
print(results_df)


In [20]:
from utils.generate_data import load_abalone_regression_data
def compute_sparse_rmse_results_abalone(models, all_fits, forward_pass,
                         sparsity=0.0, prune_fn=None, prune_output_layer=False):
    """Score each model's posterior weight draws on the abalone test split after pruning.

    For every model, the draws of ``W_1``, ``W_L``, ``hidden_bias`` and
    ``output_bias`` are read from ``all_fits[model]['posterior']``. When a
    pruning function is given, each draw is masked at the requested sparsity.
    The draw is then pushed through ``forward_pass`` on the test inputs and
    compared with the test targets.

    Parameters
    ----------
    models : iterable of str
        Keys into ``all_fits``. A model whose entry or any required Stan
        variable is missing is skipped, with a printed message.
    all_fits : mapping
        ``{model: {'posterior': fit}}``, where ``fit.stan_variable(name)``
        returns an array of draws with the draw index first.
    forward_pass : callable
        ``forward_pass(X, W1, b1, W2, b2) -> y_hat`` for a single draw.
    sparsity : float, default 0.0
        Passed to ``prune_fn``. Pruning is applied only when ``> 0``.
    prune_fn : callable, optional
        ``prune_fn([W1, W2], sparsity) -> [mask_1, mask_2]``.
    prune_output_layer : bool, default False
        When True, also multiply ``W2`` by ``mask_2``. The default keeps the
        original behaviour of pruning only the first-layer weights.

    Returns
    -------
    df_rmse : pandas.DataFrame
        One row per (model, draw) with columns ``model``, ``sparsity``, ``rmse``.
    df_posterior_rmse : pandas.DataFrame
        One row per scored model with columns ``model``, ``sparsity``,
        ``posterior_mean_rmse`` (the RMSE of the draw-averaged prediction).
    """
    # The test split does not depend on the model, so load it once. Flatten the
    # targets so residuals never broadcast (n,) against (n, 1).
    _, X_test, _, y_test = load_abalone_regression_data(standardized=False, frac=1.0)
    y_true = np.asarray(y_test, dtype=float).reshape(-1)

    results = []
    posterior_means = []
    for model in models:
        try:
            fit = all_fits[model]['posterior']
            W1_samples = fit.stan_variable("W_1")           # (S, P, H)
            W2_samples = fit.stan_variable("W_L")           # (S, H, O)
            b1_samples = fit.stan_variable("hidden_bias")   # (S, O, H); only row 0 is used below
            b2_samples = fit.stan_variable("output_bias")   # (S, O)
        except (KeyError, ValueError) as err:
            # KeyError: model/posterior missing from all_fits.
            # ValueError: cmdstanpy's error for an unknown Stan variable name.
            print(f"[SKIP] Model or posterior not found: {model!r} ({type(err).__name__}: {err})")
            continue

        S = W1_samples.shape[0]
        y_hats = np.zeros((S, y_true.shape[0]))

        for i in range(S):
            W1 = W1_samples[i]
            W2 = W2_samples[i]

            # Apply the pruning mask if requested
            if prune_fn is not None and sparsity > 0.0:
                masks = prune_fn([W1, W2], sparsity)
                W1 = W1 * masks[0]
                if prune_output_layer:
                    W2 = W2 * masks[1]

            y_hat = forward_pass(X_test, W1, b1_samples[i][0], W2, b2_samples[i])
            y_hats[i] = np.asarray(y_hat).reshape(-1)

        # Per-draw RMSE, plus the RMSE of the posterior-mean prediction
        rmses = np.sqrt(np.mean((y_hats - y_true[None, :]) ** 2, axis=1))
        posterior_mean = y_hats.mean(axis=0)
        posterior_mean_rmse = np.sqrt(np.mean((posterior_mean - y_true) ** 2))

        posterior_means.append({
            'model': model,
            'sparsity': sparsity,
            'posterior_mean_rmse': posterior_mean_rmse
        })
        results.extend(
            {'model': model, 'sparsity': sparsity, 'rmse': rmse} for rmse in rmses
        )

    df_rmse = pd.DataFrame(results)
    df_posterior_rmse = pd.DataFrame(posterior_means)

    return df_rmse, df_posterior_rmse


In [21]:
from utils.sparsity import forward_pass_relu, forward_pass_tanh, local_prune_weights

# Prune each posterior draw at increasing sparsity and record test RMSEs,
# separately for the ReLU and tanh fits.
sparsity_levels = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95]

df_rmse_relu, df_posterior_rmse_relu = {}, {}
df_rmse_tanh, df_posterior_rmse_tanh = {}, {}

# One row per activation:
# (model names, fits, forward pass, per-draw store, posterior-mean store)
activation_runs = (
    (model_names_relu, abalone_relu_fit, forward_pass_relu, df_rmse_relu, df_posterior_rmse_relu),
    (model_names_tanh, abalone_tanh_fit, forward_pass_tanh, df_rmse_tanh, df_posterior_rmse_tanh),
)

for sparsity in sparsity_levels:
    # ReLU first, then tanh, at every sparsity level
    for names, activation_fits, fwd, per_draw_store, post_mean_store in activation_runs:
        per_draw_store[sparsity], post_mean_store[sparsity] = compute_sparse_rmse_results_abalone(
            models=names,
            all_fits=activation_fits,
            forward_pass=fwd,
            sparsity=sparsity,
            prune_fn=local_prune_weights,
        )


In [None]:
# Combine the per-sparsity frames into one long frame per activation
df_rmse_full_relu = pd.concat(
    [df.assign(sparsity=sparsity) for sparsity, df in df_rmse_relu.items()],
    ignore_index=True
)

df_rmse_full_tanh = pd.concat(
    [df.assign(sparsity=sparsity) for sparsity, df in df_rmse_tanh.items()],
    ignore_index=True
)

# Plot: mean per-draw RMSE against sparsity, one panel per activation
import matplotlib.pyplot as plt
import seaborn as sns
custom_palette = {
    "Dirichlet Horseshoe": "C2",
    "Dirichlet Student T": "C3",
    "Beta Horseshoe": "C4",
    "Beta Student T": "C5",
}
abbr = {
    "Dirichlet Horseshoe": "DHS",
    "Dirichlet Student T": "DST",
    "Beta Horseshoe": "BHS",
    "Beta Student T": "BST",
}
# Strip the activation suffix so both panels share the same model labels
df_rmse_full_relu["model"] = df_rmse_full_relu["model"].str.replace(" tanh", "", regex=False)
df_rmse_full_tanh["model"] = df_rmse_full_tanh["model"].str.replace(" tanh", "", regex=False)

# Legend order follows `custom_palette`; models not listed there go last,
# keeping their first-seen order.
palette_order = list(custom_palette)

def _model_sort_key(model_name):
    """Position of a model in the canonical legend order (unknown models last)."""
    return palette_order.index(model_name) if model_name in palette_order else len(palette_order)

fig, axes = plt.subplots(1, 2, figsize=(14, 5), sharex=True, sharey=True)
activation_data = [("ReLU", df_rmse_full_relu), ("tanh", df_rmse_full_tanh)]

for ax, (name, df) in zip(axes, activation_data):
    df["model_abbr"] = df["model"].map(lambda m: abbr.get(m, m))

    # Build hue order and palette from the same label mapping as `model_abbr`,
    # so a model missing from `abbr`/`custom_palette` cannot raise a KeyError
    # or leave seaborn without a colour (it falls back to grey).
    present_models = sorted(df["model"].unique(), key=_model_sort_key)
    hue_order = [abbr.get(m, m) for m in present_models]
    palette = {abbr.get(m, m): custom_palette.get(m, "0.5") for m in present_models}

    sns.lineplot(
        data=df,
        x='sparsity', y='rmse',
        hue='model_abbr', marker='o', errorbar=None, ax=ax,
        palette=palette,
        hue_order=hue_order,
    )

    ax.set_title(name)
    ax.set_xlabel("Sparsity level")
    ax.set_ylabel("RMSE")
    ax.grid(True)
    ax.legend(title="Model", loc="upper left")

plt.tight_layout()
plt.show()
