In [None]:
import sys, os; sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__) if '__file__' in globals() else os.getcwd(), '..')))
#import os; os.chdir(os.path.dirname(os.getcwd()))
from utils.model_loader import get_model_fits
import numpy as np
import pandas as pd
import re
from sklearn.metrics import mean_squared_error
import seaborn as sns
import matplotlib.pyplot as plt


In [None]:
# Uncorrelated Friedman benchmark: dataset folder and per-activation result folders.
data_dir = "datasets/friedman_noisy"
results_dir_relu = "results/regression/single_layer/relu/friedman_noisy"
results_dir_tanh = "results/regression/single_layer/tanh/friedman_noisy"

# Priors evaluated in this notebook. The tanh fits are stored under the same
# names with a " tanh" suffix. ("Gaussian" and "Regularized Horseshoe" fits are
# not loaded here; add them to the first list to include them.)
model_names_relu = ["Dirichlet Horseshoe", "Dirichlet Student T", "Beta Horseshoe", "Beta Student T"]
model_names_tanh = [f"{prior} tanh" for prior in model_names_relu]

# One entry per dataset file, keyed by the file name without its extension,
# e.g. "Friedman_N100_p10_sigma1.0_seed1".
relu_fits = {}
tanh_fits = {}

files = sorted(filter(lambda f: f.endswith(".npz"), os.listdir(data_dir)))
for fname in files:
    base_config_name = fname.replace(".npz", "")

    # ReLU fits are loaded first, then tanh, for every dataset.
    relu_fits[base_config_name] = get_model_fits(
        config=base_config_name,
        results_dir=results_dir_relu,
        models=model_names_relu,
        include_prior=False,
    )
    tanh_fits[base_config_name] = get_model_fits(
        config=base_config_name,
        results_dir=results_dir_tanh,
        models=model_names_tanh,
        include_prior=False,
    )
    


In [None]:
# Correlated-features variant of the Friedman benchmark.
# NOTE: this rebinds `data_dir`, which the previous cell pointed at the uncorrelated data.
data_dir = "datasets/friedman_noisy_correlated"
results_dir_relu_correlated = "results/regression/single_layer/relu/friedman_noisy_correlated"
results_dir_tanh_correlated = "results/regression/single_layer/tanh/friedman_noisy_correlated"

# Same layout as relu_fits / tanh_fits: dataset name (file name without ".npz") -> fits.
relu_fits_correlated = {}
tanh_fits_correlated = {}

files = sorted(filter(lambda f: f.endswith(".npz"), os.listdir(data_dir)))
for fname in files:
    base_config_name = fname.replace(".npz", "")

    # ReLU fits are loaded first, then tanh, for every dataset.
    relu_fits_correlated[base_config_name] = get_model_fits(
        config=base_config_name,
        results_dir=results_dir_relu_correlated,
        models=model_names_relu,
        include_prior=False,
    )
    tanh_fits_correlated[base_config_name] = get_model_fits(
        config=base_config_name,
        results_dir=results_dir_tanh_correlated,
        models=model_names_tanh,
        include_prior=False,
    )
    


In [4]:
import re
import numpy as np
import pandas as pd
from properscoring import crps_ensemble
from scores.probability import crps_for_ensemble

# Matches dataset names such as "Friedman_N100_p10_sigma1.0_seed3" anywhere in a string.
_FRIEDMAN_KEY = re.compile(r"Friedman_N(\d+)_p\d+_sigma([\d.]+)_seed(\d+)")

def extract_friedman_metadata(key: str):
    """
    Pull the sample size, noise level and seed out of a Friedman dataset name.

    Args:
        key: a string containing 'Friedman_N{N}_p{p}_sigma{sigma}_seed{seed}'.

    Returns:
        (N, sigma, seed) as (int, float, int), or (None, None, None) when the
        pattern does not occur in `key`.
    """
    match = _FRIEDMAN_KEY.search(key)
    if match is None:
        return None, None, None

    n_text, sigma_text, seed_text = match.groups()
    return int(n_text), float(sigma_text), int(seed_text)


In [5]:
def compute_rmse_from_fits(all_fits, model_names=None, folder="friedman_noisy"):
    """
    Compute test RMSE for every dataset/model in `all_fits` (e.g. relu_fits or tanh_fits).

    For each model in `model_names` (or every model stored for the dataset when
    `model_names` is None/empty) this computes:
      - the RMSE of each individual posterior draw of `output_test`
      - the RMSE of the posterior-mean predictor

    Args:
        all_fits: dict mapping dataset name -> {model name -> {"posterior": fit}},
            where `fit.stan_variable("output_test")` has shape (S, N_test) or
            (S, N_test, 1).
        model_names: optional list of model names to evaluate.
        folder: sub-folder of `datasets/` that holds the `.npz` files; the file
            is looked up in `datasets/{folder}/` first and `datasets/{folder}/many/`
            second.

    Returns:
        df_rmse: long DataFrame with one row per posterior draw.
        df_posterior_rmse: one row per model/dataset with the posterior-mean RMSE.
        Both frames always carry their full set of columns, even when empty.

    Raises:
        FileNotFoundError: if the dataset file exists in neither location.
        ValueError: if `output_test` has an unexpected shape.
    """
    rmse_columns = ["dataset_key", "model", "N", "sigma", "seed", "sample_idx", "rmse"]
    post_mean_columns = ["dataset_key", "model", "N", "sigma", "seed", "posterior_mean_rmse"]

    rmse_rows = []
    post_mean_rows = []

    for dataset_key, model_dict in all_fits.items():
        N, sigma, seed = extract_friedman_metadata(dataset_key)
        if N is None:
            # Skip non-Friedman entries if any
            continue

        # Load the test targets; fall back to the "many" sub-folder.
        fname = f"Friedman_N{N}_p10_sigma{sigma:.1f}_seed{seed}.npz"
        try:
            archive = np.load(f"datasets/{folder}/{fname}")
        except FileNotFoundError:
            archive = np.load(f"datasets/{folder}/many/{fname}")
        with archive:  # close the underlying zip file once y_test is in memory
            # atleast_1d keeps a single test point as shape (1,) instead of 0-d
            y_test = np.atleast_1d(archive["y_test"].squeeze())  # (N_test,)

        # Choose which models to evaluate
        models_to_eval = model_names or list(model_dict.keys())

        for model in models_to_eval:
            # Some entries may be missing
            entry = model_dict.get(model)
            if not entry or "posterior" not in entry:
                print(f"[SKIP] Missing posterior: {dataset_key} -> {model}")
                continue

            fit = entry["posterior"]

            # Expecting (S, N_test, 1) or (S, N_test)
            output_test = fit.stan_variable("output_test")
            if output_test.ndim == 3 and output_test.shape[-1] == 1:
                preds = output_test[..., 0]  # (S, N_test)
            elif output_test.ndim == 2:
                preds = output_test  # (S, N_test)
            else:
                raise ValueError(f"Unexpected output_test shape {output_test.shape} for {dataset_key} -> {model}")

            meta = {
                "dataset_key": dataset_key,
                "model": model,
                "N": N,
                "sigma": sigma,
                "seed": seed,
            }

            # Per-sample RMSE
            sq_err = (preds - y_test[None, :])**2  # (S, N_test)
            rmse_per_sample = np.sqrt(np.mean(sq_err, axis=1))  # (S,)
            rmse_rows.extend(
                {**meta, "sample_idx": s_idx, "rmse": float(rmse)}
                for s_idx, rmse in enumerate(rmse_per_sample)
            )

            # Posterior-mean RMSE
            posterior_mean = preds.mean(axis=0)  # (N_test,)
            post_mean_rmse = float(np.sqrt(np.mean((posterior_mean - y_test)**2)))
            post_mean_rows.append({**meta, "posterior_mean_rmse": post_mean_rmse})

    # Explicit columns so that an empty result still has the expected schema.
    df_rmse = pd.DataFrame(rmse_rows, columns=rmse_columns)
    df_posterior_rmse = pd.DataFrame(post_mean_rows, columns=post_mean_columns)
    return df_rmse, df_posterior_rmse


def compute_crps_from_fits(all_fits, model_names=None, folder="friedman_noisy"):
    """
    Compute the mean CRPS per dataset/model from all posterior predictive draws.

    Args:
        all_fits: dict mapping dataset name -> {model name -> {"posterior": fit}},
            where `fit.stan_variable("output_test")` has shape (S, N_test) or
            (S, N_test, 1).
        model_names: optional list of model names to evaluate; defaults to every
            model stored for the dataset.
        folder: sub-folder of `datasets/` that holds the `.npz` files; the file
            is looked up in `datasets/{folder}/` first and `datasets/{folder}/many/`
            second.

    Returns:
        df_crps: one row per dataset/model with the CRPS averaged over test
        points. The frame always carries its full set of columns, even when empty.

    Raises:
        FileNotFoundError: if the dataset file exists in neither location.
        ValueError: if `output_test` has an unexpected shape.
    """
    columns = ["dataset_key", "model", "N", "sigma", "seed", "crps"]
    rows = []

    for dataset_key, model_dict in all_fits.items():
        N, sigma, seed = extract_friedman_metadata(dataset_key)
        if N is None:
            continue

        # Load the test targets; fall back to the "many" sub-folder.
        fname = f"Friedman_N{N}_p10_sigma{sigma:.1f}_seed{seed}.npz"
        try:
            archive = np.load(f"datasets/{folder}/{fname}")
        except FileNotFoundError:
            archive = np.load(f"datasets/{folder}/many/{fname}")
        with archive:  # close the underlying zip file once y_test is in memory
            # atleast_1d keeps a single test point as shape (1,) instead of 0-d
            y_test = np.atleast_1d(archive["y_test"].squeeze())  # (N_test,)

        models_to_eval = model_names or list(model_dict.keys())

        for model in models_to_eval:
            entry = model_dict.get(model)
            if not entry or "posterior" not in entry:
                print(f"[SKIP] Missing posterior: {dataset_key} -> {model}")
                continue

            fit = entry["posterior"]
            output_test = fit.stan_variable("output_test")

            # Expecting (S, N_test, 1) or (S, N_test)
            if output_test.ndim == 3 and output_test.shape[-1] == 1:
                preds = output_test[..., 0]  # (S, N_test)
            elif output_test.ndim == 2:
                preds = output_test  # (S, N_test)
            else:
                raise ValueError(f"Unexpected output_test shape {output_test.shape} for {dataset_key} -> {model}")

            # crps_ensemble expects shape (N_test, S)
            crps_point = crps_ensemble(y_test, preds.T)  # (N_test,)
            rows.append({
                "dataset_key": dataset_key,
                "model": model,
                "N": N,
                "sigma": sigma,
                "seed": seed,
                "crps": float(crps_point.mean())
            })

    # Explicit columns so that an empty result still has the expected schema.
    df_crps = pd.DataFrame(rows, columns=columns)
    return df_crps

import numpy as np
import pandas as pd
import xarray as xr

def compute_crps_from_fits_distributional(
    all_fits,
    model_names=None,
    folder="friedman_noisy",
    return_pointwise=True,
    add_obs_summaries=True,
):
    """
    Compute the CRPS of every test point from the posterior predictive ensemble.

    The score is evaluated with `crps_for_ensemble(..., method="fair")`, treating
    the posterior draws of `output_test` as ensemble members.

    Args:
        all_fits: dict mapping dataset name -> {model name -> {"posterior": fit}},
            where `fit.stan_variable("output_test")` has shape (S, N_test) or
            (S, N_test, 1).
        model_names: optional list of model names to evaluate; defaults to every
            model stored for the dataset.
        folder: sub-folder of `datasets/` that holds the `.npz` files; the file
            is looked up in `datasets/{folder}/` first and `datasets/{folder}/many/`
            second.
        return_pointwise: when True, also build the per-test-point frame.
        add_obs_summaries: when True, fill the per-dataset/model summary frame.

    Returns:
        df_pointwise: one row per (dataset_key, model, test_idx) with the CRPS of
            that test point, or None when `return_pointwise` is False.
        df_summary: one row per (dataset_key, model) with mean/median/quantiles/
            std of the CRPS over test points (no rows when `add_obs_summaries`
            is False).
        Both frames always carry their full set of columns, even when empty.

    Raises:
        FileNotFoundError: if the dataset file exists in neither location.
        ValueError: if `output_test` has an unexpected shape.
    """
    pointwise_columns = ["dataset_key", "model", "N", "sigma", "seed", "test_idx", "y_true", "crps"]
    summary_columns = [
        "dataset_key", "model", "N", "sigma", "seed",
        "crps_mean", "crps_median", "crps_q10", "crps_q90", "crps_std_over_test",
    ]

    rows_pointwise = []
    rows_summary = []

    for dataset_key, model_dict in all_fits.items():
        N, sigma, seed = extract_friedman_metadata(dataset_key)
        if N is None:
            continue

        # load y_test; fall back to the "many" sub-folder
        fname = f"Friedman_N{N}_p10_sigma{sigma:.1f}_seed{seed}.npz"
        try:
            archive = np.load(f"datasets/{folder}/{fname}")
        except FileNotFoundError:
            archive = np.load(f"datasets/{folder}/many/{fname}")
        with archive:  # close the underlying zip file once y_test is in memory
            # atleast_1d keeps a single test point as shape (1,) instead of 0-d
            y_test = np.atleast_1d(archive["y_test"].squeeze())  # (N_test,)

        models_to_eval = model_names or list(model_dict.keys())

        for model in models_to_eval:
            entry = model_dict.get(model)
            if not entry or "posterior" not in entry:
                print(f"[SKIP] Missing posterior: {dataset_key} -> {model}")
                continue

            fit = entry["posterior"]
            output_test = fit.stan_variable("output_test")

            # (S, N_test, 1) or (S, N_test)
            if output_test.ndim == 3 and output_test.shape[-1] == 1:
                preds = output_test[..., 0]  # (S, N_test)
            elif output_test.ndim == 2:
                preds = output_test  # (S, N_test)
            else:
                raise ValueError(f"Unexpected output_test shape {output_test.shape} for {dataset_key} -> {model}")

            # Wrap forecasts/observations so the ensemble and test dimensions are named.
            fcst = xr.DataArray(preds, dims=["ensemble_member", "test_index"])
            obs = xr.DataArray(y_test, dims=["test_index"])

            # per test point CRPS, using the full posterior predictive ensemble
            crps_da = crps_for_ensemble(
                fcst,
                obs,
                ensemble_member_dim="ensemble_member",
                preserve_dims="test_index",
                method="fair"
            )
            crps_i = crps_da.values

            meta = {
                "dataset_key": dataset_key,
                "model": model,
                "N": N,
                "sigma": sigma,
                "seed": seed,
            }

            # store pointwise distribution across test points
            if return_pointwise:
                rows_pointwise.extend(
                    {**meta, "test_idx": i, "y_true": float(y_test[i]), "crps": float(val)}
                    for i, val in enumerate(crps_i)
                )

            # store summaries over test points
            if add_obs_summaries:
                rows_summary.append({
                    **meta,
                    "crps_mean": float(np.mean(crps_i)),
                    "crps_median": float(np.median(crps_i)),
                    "crps_q10": float(np.quantile(crps_i, 0.10)),
                    "crps_q90": float(np.quantile(crps_i, 0.90)),
                    # sample std is undefined for a single test point
                    "crps_std_over_test": float(np.std(crps_i, ddof=1)) if len(crps_i) > 1 else 0.0,
                })

    # Explicit columns so that empty results still have the expected schema.
    df_pointwise = pd.DataFrame(rows_pointwise, columns=pointwise_columns) if return_pointwise else None
    df_summary = pd.DataFrame(rows_summary, columns=summary_columns)

    return df_pointwise, df_summary


In [6]:
# RMSE (per posterior draw and of the posterior mean) for each activation/setting.
# Pass model_names=None to evaluate every model found in the fits instead.
# The CRPS counterparts can be obtained the same way with
# compute_crps_from_fits_distributional(fits, model_names, folder=...).

# ReLU, uncorrelated features
df_rmse_relu, df_posterior_rmse_relu = compute_rmse_from_fits(
    all_fits=relu_fits,
    model_names=model_names_relu,
)

# ReLU, correlated features
df_rmse_relu_correlated, df_posterior_rmse_relu_correlated = compute_rmse_from_fits(
    all_fits=relu_fits_correlated,
    model_names=model_names_relu,
    folder="friedman_noisy_correlated",
)

# tanh, uncorrelated features
df_rmse_tanh, df_posterior_rmse_tanh = compute_rmse_from_fits(
    all_fits=tanh_fits,
    model_names=model_names_tanh,
)

# tanh, correlated features
df_rmse_tanh_correlated, df_posterior_rmse_tanh_correlated = compute_rmse_from_fits(
    all_fits=tanh_fits_correlated,
    model_names=model_names_tanh,
    folder="friedman_noisy_correlated",
)


In [7]:
import pandas as pd

# Tag every per-draw RMSE frame with its activation and data setting, then stack them.
df1, df2, df3, df4 = (
    frame.assign(activation=act, setting=data_setting)
    for frame, act, data_setting in (
        (df_rmse_relu, "ReLU", "Original"),
        (df_rmse_tanh, "Tanh", "Original"),
        (df_rmse_relu_correlated, "ReLU", "Correlated"),
        (df_rmse_tanh_correlated, "Tanh", "Correlated"),
    )
)
df_all = pd.concat([df1, df2, df3, df4], ignore_index=True)

# Same for the posterior-mean RMSE frames.
df1_pm, df2_pm, df3_pm, df4_pm = (
    frame.assign(activation=act, setting=data_setting)
    for frame, act, data_setting in (
        (df_posterior_rmse_relu, "ReLU", "Original"),
        (df_posterior_rmse_tanh, "Tanh", "Original"),
        (df_posterior_rmse_relu_correlated, "ReLU", "Correlated"),
        (df_posterior_rmse_tanh_correlated, "Tanh", "Correlated"),
    )
)
df_all_pm = pd.concat([df1_pm, df2_pm, df3_pm, df4_pm], ignore_index=True)


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# --- prepare data ---
# Work on a copy with a model name shared across activations (the " tanh" suffix is dropped).
df = df_all.assign(
    model_clean=lambda frame: frame["model"].str.replace(" tanh", "", regex=False)
)

# Legend abbreviations for the priors.
abbr = {
    "Gaussian": "Gauss",
    "Regularized Horseshoe": "RHS",
    "Dirichlet Horseshoe": "DHS",
    "Dirichlet Student T": "DST",
    "Beta Horseshoe": "BHS",
    "Beta Student T": "BST",
}

# Mean and standard deviation of the per-draw RMSE for each (setting, N, model, activation).
summary = (
    df.groupby(["setting", "N", "model_clean", "activation"], as_index=False)
      .agg(mean=("rmse", "mean"), std=("rmse", "std"))
)

# Order in which panels, sample sizes and priors are drawn.
settings = ["Original", "Correlated"]
Ns = [50, 100, 200, 500]
models = ["Gaussian", "Regularized Horseshoe", "Dirichlet Horseshoe", "Dirichlet Student T", "Beta Horseshoe", "Beta Student T"]

# Marker shape encodes the activation, colour encodes the prior; both get a
# small horizontal shift so that points sharing an N do not overlap.
markers = {"ReLU": "o", "Tanh": "^"}
offsets = {"ReLU": -0.12, "Tanh": +0.12}
model_offsets = dict(zip(models, (-0.08, -0.05, -0.02, +0.02, +0.05, +0.08)))
palette_list = plt.get_cmap("tab10").colors
palette = dict(zip(models, palette_list))

# Each N gets an integer slot on the x axis.
xbase = dict(zip(Ns, range(len(Ns))))

fig, axes = plt.subplots(1, 2, figsize=(12, 7), sharey=True)

for ax, setting in zip(axes, settings):
    sub = summary.loc[summary["setting"] == setting]
    # One marker with error bar per (prior, activation, N); points are not connected.
    for m in models:
        for act in ("ReLU", "Tanh"):
            g = sub.loc[(sub["model_clean"] == m) & (sub["activation"] == act)]
            if g.empty:
                continue
            xs = [xbase[n] + offsets[act] + model_offsets[m] for n in g["N"]]
            ax.errorbar(
                xs,
                g["mean"],
                yerr=g["std"],
                fmt=markers[act],
                markersize=10,
                linestyle="none",
                capsize=3,
                color=palette[m],
                markeredgecolor="black",
            )

    ax.set_title(setting, fontsize=15)
    ax.set_xticks(range(len(Ns)))
    ax.set_xticklabels(Ns, fontsize=15)
    ax.set_xlabel("N", fontsize=15)
    ax.set_ylabel("RMSE", fontsize=15)
    ax.tick_params(axis='y', labelsize=15)
    ax.grid()

# Legend entries: one coloured dot per prior (abbreviated) ...
model_handles = []
for m in models:
    model_handles.append(
        Line2D(
            [0], [0],
            marker="o",
            linestyle="none",
            color=palette[m],
            markeredgecolor="black",
            markersize=12,
            label=abbr.get(m, m),
        )
    )

# ... followed by one black marker per activation.
activation_handles = [
    Line2D([0], [0], marker=markers[act], linestyle="none", color="black",
           markersize=12, label=act)
    for act in ("ReLU", "Tanh")
]

for ax in axes:
    ax.legend(
        handles=model_handles + activation_handles,
        title=None,
        loc="upper right",
        frameon=False,
        ncol=1,
        fontsize=14,
    )

plt.tight_layout(rect=(0, 0, 1, 1))
plt.savefig("figures_for_use_in_paper/friedman_RMSE_with_beta.pdf", bbox_inches="tight")
plt.show()

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
# --- prepare data ---
df = df_all_pm.copy()

# Legend abbreviations for the priors (kept in sync with the per-draw RMSE
# figure so both legends read DHS/DST/BHS/BST instead of mixing in full names).
abbr = {
    "Gaussian": "Gauss",
    "Regularized Horseshoe": "RHS",
    "Dirichlet Horseshoe": "DHS",
    "Dirichlet Student T": "DST",
    "Beta Horseshoe": "BHS",
    "Beta Student T": "BST",
}

# unify model names across activations (strip " tanh")
df["model_clean"] = df["model"].str.replace(" tanh", "", regex=False)

# summary stats of the posterior-mean RMSE per (setting, N, model, activation);
# only the mean is plotted below
summary = (
    df.groupby(["setting", "N", "model_clean", "activation"], as_index=False)["posterior_mean_rmse"]
      .agg(mean="mean", std="std")
)
# plotting order
settings = ["Original", "Correlated"]
Ns = [50, 100, 200, 500]
models = ["Gaussian", "Regularized Horseshoe", "Dirichlet Horseshoe", "Dirichlet Student T", "Beta Horseshoe", "Beta Student T"]

# visuals
markers = {"ReLU": "o", "Tanh": "^"}            # shapes
offsets = {"ReLU": -0.12, "Tanh": +0.12}        # side-by-side jitter on x
# Per-prior jitter; priors without an entry are drawn without extra shift
# (looked up with .get below) instead of raising KeyError.
model_offsets = {
    "Dirichlet Horseshoe": -0.06,
    "Dirichlet Student T": -0.02,
    "Beta Horseshoe": +0.02,
    "Beta Student T": +0.06,
}
palette_list = plt.get_cmap("tab10").colors
palette = {m: palette_list[i] for i, m in enumerate(models)}

# map N to base x positions and add offsets for activation
xbase = {N: i for i, N in enumerate(Ns)}

fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharey=True)

for ax, setting in zip(axes, settings):
    # Only sample sizes that have a slot on the x axis can be drawn.
    sub = summary[(summary["setting"] == setting) & (summary["N"].isin(Ns))]
    # plot each model+activation as unconnected markers
    for m in models:
        for act in ["ReLU", "Tanh"]:
            g = sub[(sub["model_clean"] == m) & (sub["activation"] == act)]
            if g.empty:
                continue
            xs = [xbase[n] + offsets[act] + model_offsets.get(m, 0.0) for n in g["N"]]

            ax.plot(
                xs, g["mean"],
                marker=markers[act],
                markersize=7,
                linestyle="none",
                color=palette[m],
                markeredgecolor="black",
            )

    ax.set_title(f"{setting}")
    ax.set_xticks(range(len(Ns)))
    ax.set_xticklabels(Ns)
    ax.set_xlabel("N")
    ax.set_ylabel("RMSE")
    ax.grid()

# --- legends ---
# prior legend (colours), labelled with the abbreviations
model_handles = [
    Line2D(
        [0], [0],
        marker="o",
        linestyle="none",
        color=palette[m],
        markeredgecolor="black",
        markersize=7,
        label=abbr.get(m, m)
    )
    for m in models
]

# activation legend (shapes)
activation_handles = [
    Line2D([0], [0], marker=markers["ReLU"], linestyle="none", color="black",
           markersize=7, label="ReLU"),
    Line2D([0], [0], marker=markers["Tanh"], linestyle="none", color="black",
           markersize=7, label="Tanh"),
]

for ax in axes:
    ax.legend(
        handles=model_handles + activation_handles,
        title=None,
        loc="upper right",
        frameon=False,
        ncol=1
    )
plt.tight_layout(rect=(0, 0, 1, 1))
plt.show()

In [30]:
from utils.sparsity import forward_pass_relu, forward_pass_tanh, local_prune_weights

def compute_sparse_rmse_results(seeds, models, all_fits, get_N_sigma, forward_pass, folder,
                         sparsity=0.0, prune_fn=None):
    """
    Re-evaluate test RMSE after (optionally) pruning the first-layer weights.

    For every seed the dataset `datasets/{folder}/Friedman_N{N}_p10_sigma{sigma}_seed{seed}.npz`
    is loaded, and for every model each posterior draw of the network weights is
    pushed through `forward_pass` on `X_test`.

    Args:
        seeds: iterable of dataset seeds to evaluate.
        models: model names to look up in `all_fits[dataset_key]`.
        all_fits: dict mapping dataset name -> {model name -> {"posterior": fit}};
            `fit.stan_variable` must provide "W_1", "W_L", "hidden_bias" and
            "output_bias", each with the posterior draw as leading axis.
        get_N_sigma: callable mapping a seed to the (N, sigma) of its dataset.
        forward_pass: callable `(X, W1, b1, W2, b2) -> predictions`.
        folder: sub-folder of `datasets/` that holds the `.npz` files.
        sparsity: value handed to `prune_fn`; pruning is skipped when it is 0.
        prune_fn: callable `([W1, W2], sparsity) -> masks`; only `masks[0]` is
            applied (to `W1`), the output-layer weights are left untouched.

    Returns:
        df_rmse: one row per (seed, model, posterior draw) with that draw's RMSE.
        df_posterior_rmse: one row per (seed, model) with the RMSE of the
            posterior-mean prediction.
        Both frames always carry their full set of columns, even when empty.

    Datasets whose file is missing and models that are absent from `all_fits`
    are reported with a `[SKIP]` message and left out of the result.
    """
    rmse_columns = ['seed', 'N', 'sigma', 'model', 'sparsity', 'rmse']
    post_mean_columns = ['seed', 'N', 'sigma', 'model', 'sparsity', 'posterior_mean_rmse']

    results = []
    posterior_means = []

    for seed in seeds:
        N, sigma = get_N_sigma(seed)
        dataset_key = f'Friedman_N{N}_p10_sigma{sigma:.1f}_seed{seed}'
        path = f"datasets/{folder}/{dataset_key}.npz"

        try:
            archive = np.load(path)
        except FileNotFoundError:
            print(f"[SKIP] File not found: {path}")
            continue
        with archive:  # close the underlying zip file once the arrays are in memory
            X_test = archive["X_test"]
            # Flatten a trailing singleton axis: subtracting an (N_test, 1) target
            # from (N_test,) predictions would broadcast to (N_test, N_test) and
            # silently produce a wrong RMSE.
            y_test = np.atleast_1d(archive["y_test"].squeeze())  # (N_test,)

        for model in models:
            try:
                fit = all_fits[dataset_key][model]['posterior']
                W1_samples = fit.stan_variable("W_1")           # (S, P, H)
                W2_samples = fit.stan_variable("W_L")           # (S, H, O)
                b1_samples = fit.stan_variable("hidden_bias")   # (S, O, H)
                b2_samples = fit.stan_variable("output_bias")   # (S, O)
            except KeyError:
                print(f"[SKIP] Model or posterior not found: {dataset_key} -> {model}")
                continue

            S = W1_samples.shape[0]
            y_hats = np.zeros((S, y_test.shape[0]))

            for i in range(S):
                W1 = W1_samples[i]
                W2 = W2_samples[i]

                # Apply the pruning mask to the first layer only, if requested
                if prune_fn is not None and sparsity > 0.0:
                    masks = prune_fn([W1, W2], sparsity)
                    W1 = W1 * masks[0]

                y_hat = forward_pass(X_test, W1, b1_samples[i][0], W2, b2_samples[i])
                y_hats[i] = y_hat.squeeze()  # prediction of draw i, shape (N_test,)

            # RMSE of every draw and of the posterior-mean prediction
            rmses = np.sqrt(np.mean((y_hats - y_test[None, :])**2, axis=1))  # (S,)
            posterior_mean = np.mean(y_hats, axis=0)
            posterior_mean_rmse = np.sqrt(np.mean((posterior_mean - y_test)**2))

            meta = {
                'seed': seed,
                'N': N,
                'sigma': sigma,
                'model': model,
                'sparsity': sparsity,
            }
            posterior_means.append({**meta, 'posterior_mean_rmse': posterior_mean_rmse})
            results.extend({**meta, 'rmse': float(rmse)} for rmse in rmses)

    # Explicit columns so that an empty result still has the expected schema.
    df_rmse = pd.DataFrame(results, columns=rmse_columns)
    df_posterior_rmse = pd.DataFrame(posterior_means, columns=post_mean_columns)

    return df_rmse, df_posterior_rmse


# Sparsity levels swept in the pruning experiment.
sparsity_levels = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95]

# Dataset seeds used for the uncorrelated and the correlated setting.
seeds = [1, 2, 11]
seeds_correlated = [1, 6, 11]

def get_N_sigma(seed):
    """Return (N, sigma) for an uncorrelated dataset seed: 1 -> N=100, 2 -> N=200, anything else -> N=500; sigma is always 10.0."""
    N = 100 if seed == 1 else (200 if seed == 2 else 500)
    return N, 10.0

def get_N_sigma_correlated(seed):
    """Return (N, sigma) for a correlated dataset seed: 1 -> N=100, 6 -> N=200, anything else -> N=500; sigma is always 10.0."""
    N = 100 if seed == 1 else (200 if seed == 6 else 500)
    return N, 10.0

In [None]:
# Results of the pruning sweep, keyed by sparsity level.
# NOTE: these names previously held the un-pruned RMSE DataFrames; from here on they are dicts.
df_rmse_relu, df_posterior_rmse_relu = {}, {}
df_rmse_relu_correlated, df_posterior_rmse_relu_correlated = {}, {}
df_rmse_tanh, df_posterior_rmse_tanh = {}, {}
df_rmse_tanh_correlated, df_posterior_rmse_tanh_correlated = {}, {}

for sparsity in sparsity_levels:
    # Pruning settings shared by the four evaluations at this level.
    prune_kwargs = {"sparsity": sparsity, "prune_fn": local_prune_weights}

    # ReLU, uncorrelated
    relu_original = compute_sparse_rmse_results(
        seeds, model_names_relu, relu_fits, get_N_sigma, forward_pass_relu,
        folder="friedman_noisy", **prune_kwargs,
    )
    df_rmse_relu[sparsity], df_posterior_rmse_relu[sparsity] = relu_original

    # ReLU, correlated
    relu_correlated = compute_sparse_rmse_results(
        seeds_correlated, model_names_relu, relu_fits_correlated, get_N_sigma_correlated, forward_pass_relu,
        folder="friedman_noisy_correlated", **prune_kwargs,
    )
    df_rmse_relu_correlated[sparsity], df_posterior_rmse_relu_correlated[sparsity] = relu_correlated

    # tanh, uncorrelated
    tanh_original = compute_sparse_rmse_results(
        seeds, model_names_tanh, tanh_fits, get_N_sigma, forward_pass_tanh,
        folder="friedman_noisy", **prune_kwargs,
    )
    df_rmse_tanh[sparsity], df_posterior_rmse_tanh[sparsity] = tanh_original

    # tanh, correlated
    tanh_correlated = compute_sparse_rmse_results(
        seeds_correlated, model_names_tanh, tanh_fits_correlated, get_N_sigma_correlated, forward_pass_tanh,
        folder="friedman_noisy_correlated", **prune_kwargs,
    )
    df_rmse_tanh_correlated[sparsity], df_posterior_rmse_tanh_correlated[sparsity] = tanh_correlated

In [32]:
import pandas as pd

def _stack_sparsity_frames(frames_by_sparsity):
    """Stack the per-sparsity frames into one frame whose `sparsity` column is set from the dict key."""
    stacked = []
    for level, frame in frames_by_sparsity.items():
        stacked.append(frame.assign(sparsity=level))
    return pd.concat(stacked, ignore_index=True)


# One long per-draw RMSE frame per activation/setting.
df_rmse_full_relu = _stack_sparsity_frames(df_rmse_relu)
df_rmse_full_relu_correlated = _stack_sparsity_frames(df_rmse_relu_correlated)
df_rmse_full_tanh = _stack_sparsity_frames(df_rmse_tanh)
df_rmse_full_tanh_correlated = _stack_sparsity_frames(df_rmse_tanh_correlated)


In [33]:
# tanh frames: copies with the " tanh" suffix removed from the model names,
# so that they line up with the ReLU model names.
df_tanh_o = df_rmse_full_tanh.assign(
    model=lambda frame: frame["model"].str.replace(" tanh", "", regex=False)
)
df_tanh_c = df_rmse_full_tanh_correlated.assign(
    model=lambda frame: frame["model"].str.replace(" tanh", "", regex=False)
)

# ReLU frames: plain copies (o = original data, c = correlated data).
df_relu_o = df_rmse_full_relu.copy()
df_relu_c = df_rmse_full_relu_correlated.copy()


In [34]:
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.lines import Line2D
import pandas as pd
from collections import OrderedDict

# --- Your palettes and abbreviations ---
# Line colour and legend abbreviation for each prior in the sparsity figure.
palette = {
    "Dirichlet Horseshoe": "C2",
    "Dirichlet Student T": "C3",
    "Beta Horseshoe": "C4",
    "Beta Student T": "C5",
}
abbr = {
    "Dirichlet Horseshoe": "DHS",
    "Dirichlet Student T": "DST",
    "Beta Horseshoe": "BHS",
    "Beta Student T": "BST",
}

def make_merged_df(
    df_tanh_o, df_tanh_c, df_relu_o, df_relu_c,
    drop_tanh_suffix=True
):
    """
    Stack the four RMSE frames into one long frame.

    Each input is copied and tagged with `activation` ("tanh" / "ReLU") and
    `setting` ("Original" / "Correlated"); the tanh frames come first. With
    `drop_tanh_suffix`, a " tanh" suffix is removed from the tanh model names.
    When at least one model occurs under both activations, only such shared
    models are kept; otherwise nothing is filtered.

    Returns a frame with the input columns plus `activation` and `setting`.
    """
    tagged = []

    def _tag(frame, activation, setting):
        # Label a frame in place and queue it for concatenation.
        frame["activation"] = activation
        frame["setting"] = setting
        tagged.append(frame)

    # tanh frames, optionally with the suffix stripped from the model names
    for source, setting in ((df_tanh_o, "Original"), (df_tanh_c, "Correlated")):
        frame = source.copy()
        has_suffix = any(" tanh" in name for name in frame["model"].unique())
        if drop_tanh_suffix and has_suffix:
            frame["model"] = frame["model"].str.replace(" tanh", "", regex=False)
        _tag(frame, "tanh", setting)

    # ReLU frames
    for source, setting in ((df_relu_o, "Original"), (df_relu_c, "Correlated")):
        _tag(source.copy(), "ReLU", setting)

    merged = pd.concat(tagged, ignore_index=True)

    # Restrict to priors available for both activations so the legend has no ghost entries
    tanh_models = set(merged.loc[merged["activation"] == "tanh", "model"].unique())
    relu_models = set(merged.loc[merged["activation"] == "ReLU", "model"].unique())
    shared = sorted(tanh_models.intersection(relu_models))
    if not shared:
        return merged
    return merged[merged["model"].isin(shared)]

# NOTE: this rebinds `df_all`, which the earlier RMSE-summary cells built with
# pd.concat and read via `df_all.copy()`; re-run that concat cell before
# re-running those plots.
df_all = make_merged_df(df_tanh_o, df_tanh_c, df_relu_o, df_relu_c)



In [35]:
def plot_rmse_one_figure(
    df_all,
    Ns=(100, 200, 500), figsize=(12, 12), title="Original vs Correlated (tanh vs ReLU)"
):
    """
    Draw RMSE against sparsity on a 2 x len(Ns) grid of panels.

    Rows are the data settings ("Original", "Correlated"), columns the sample
    sizes in `Ns`. Within a panel, colour encodes the prior and line style the
    activation. Two figure-level legends (priors, activations) replace the
    per-panel ones.

    Args:
        df_all: long frame with columns N, sparsity, rmse, model, activation, setting.
        Ns: sample sizes, one column of panels each.
        figsize: figure size passed to `plt.subplots`.
        title: accepted for API compatibility; currently not drawn.

    Relies on the module-level `palette` and `abbr` dicts; every prior present
    in `df_all` must have an entry in both. Updates `plt.rcParams` and the
    seaborn style globally.
    """
    setting_order = ["Original", "Correlated"]
    activation_order = ["tanh", "ReLU"]
    known_priors = ("Gaussian", "Regularized Horseshoe", "Dirichlet Horseshoe",
                    "Dirichlet Student T", "Beta Horseshoe", "Beta Student T")

    # Position of each prior in `palette`, used to order the hue levels.
    palette_rank = {name: rank for rank, name in enumerate(palette)}

    sns.set_style("whitegrid")
    plt.rcParams.update({
        "axes.spines.top": False,
        "axes.spines.right": False,
        "legend.frameon": True
    })

    fig, axes = plt.subplots(2, len(Ns), figsize=figsize, sharex=True, sharey=False)
    if len(Ns) == 1:
        # keep 2-D indexing valid when there is a single column
        axes = axes.reshape(2, 1)

    for col, Nval in enumerate(Ns):
        for row, setting in enumerate(setting_order):
            ax = axes[row, col]
            in_panel = (df_all["N"] == Nval) & (df_all["setting"] == setting)
            panel = df_all[in_panel].copy()
            if panel.empty:
                # nothing to draw for this combination
                ax.set_visible(False)
                continue

            panel["model_abbr"] = panel["model"].map(lambda m: abbr.get(m, m))
            priors_here = panel["model"].unique()
            ordered_priors = sorted(priors_here, key=lambda name: palette_rank.get(name, 999))

            sns.lineplot(
                data=panel,
                x="sparsity",
                y="rmse",
                hue="model_abbr",      # colour = prior (abbreviated)
                style="activation",    # line style = activation
                markers=True,
                dashes=True,
                palette={abbr[k]: v for k, v in palette.items() if k in priors_here},
                hue_order=[abbr[m] for m in ordered_priors],
                style_order=activation_order,
                errorbar=None,
                ax=ax,
            )
            ax.set_title(f"N={Nval}", fontweight="normal", fontsize=15)
            ax.set_xlabel("Sparsity", fontsize=15)
            ax.set_ylabel("" if j_is_not_first(col) else "RMSE", fontsize=15) if False else ax.set_ylabel("RMSE" if col == 0 else "", fontsize=15)
            ax.tick_params(axis='both', labelsize=10)
            ax.grid(True, which="major", alpha=0.25)
            if ax.legend_:  # drop the per-panel legend
                ax.legend_.remove()

    # Prior legend: one coloured line per prior that occurs in the data.
    models_present = [m for m in known_priors if (df_all["model"] == m).any()]
    prior_handles = [
        Line2D([0], [0], color=palette[m], marker='o', linestyle='-', linewidth=2, markersize=12)
        for m in models_present
    ]
    prior_labels = [abbr[m] for m in models_present]

    # Activation legend: black lines, solid for tanh and dashed for ReLU.
    act_style = {
        "tanh": dict(linestyle='-'),
        "ReLU": dict(linestyle='--'),
    }
    activation_labels = [act for act in activation_order if (df_all["activation"] == act).any()]
    activation_handles = [
        Line2D([0], [0], color='black', linewidth=2, markersize=12, **act_style[act])
        for act in activation_labels
    ]

    # Both legends share one anchor: priors extend to its left, activations to its right.
    if prior_handles:
        leg1 = fig.legend(
            prior_handles, prior_labels,
            title="Prior",
            loc="upper right",
            ncol=len(prior_handles),
            frameon=True,
            bbox_to_anchor=(0.7, 1.02),
            fontsize=15,
        )
        fig.add_artist(leg1)
    if activation_handles:
        fig.legend(
            activation_handles, activation_labels,
            title="Activation",
            loc="upper left",
            ncol=len(activation_handles),
            frameon=True,
            bbox_to_anchor=(0.7, 1.02),
            fontsize=15,
        )

    plt.tight_layout(rect=[0, 0.4, 0.95, 0.95])
    plt.show()


In [None]:
# Sparsity-vs-RMSE figure for the three sample sizes of the pruning experiment.
plot_rmse_one_figure(df_all, Ns=(100, 200, 500), title="Original vs Correlated")


In [None]:
# Shell commands, not Python: run them from a terminal (presumably at the repository root).
# They fit four tanh models one after another; `&&` stops the chain at the first failure.
# NOTE(review): the output directory .../tanh/friedman_correlated/no_lambda is not one of the
# result directories read earlier in this notebook (friedman_noisy, friedman_noisy_correlated) — confirm.
python3 utils/run_all_regression_models.py --model beta_tau_tanh --output_dir results/regression/single_layer/tanh/friedman_correlated/no_lambda &&
python3 utils/run_all_regression_models.py --model dirichlet_tau_tanh --output_dir results/regression/single_layer/tanh/friedman_correlated/no_lambda &&
python3 utils/run_all_regression_models.py --model dirichlet_horseshoe_tanh --output_dir results/regression/single_layer/tanh/friedman_correlated/no_lambda &&
python3 utils/run_all_regression_models.py --model dirichlet_student_t_tanh --output_dir results/regression/single_layer/tanh/friedman_correlated/no_lambda