In [12]:
import sys, os; sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__) if '__file__' in globals() else os.getcwd(), '..')))
#import os; os.chdir(os.path.dirname(os.getcwd()))
from utils.model_loader import get_model_fits
import numpy as np
import pandas as pd
import re
from sklearn.metrics import mean_squared_error
import seaborn as sns
import matplotlib.pyplot as plt


In [None]:
# Load posterior fits for every uncorrelated Friedman dataset found in `data_dir`.
# Keys are the dataset basenames without extension, e.g.
# "Friedman_N100_p10_sigma1.00_seed1"; values are whatever `get_model_fits`
# returns for that config (later indexed as fits[key][model]["posterior"]).
data_dir = "datasets/friedman"
results_dir_relu = "results/regression/single_layer/relu/friedman"
results_dir_tanh = "results/regression/single_layer/tanh/friedman"

# The earlier 4-model runs used the first four names of each list.
model_names_relu = ["Gaussian", "Regularized Horseshoe", "Dirichlet Horseshoe", "Dirichlet Student T", "Beta Horseshoe", "Beta Student T"]
model_names_tanh = ["Gaussian tanh", "Regularized Horseshoe tanh", "Dirichlet Horseshoe tanh", "Dirichlet Student T tanh", "Beta Horseshoe tanh", "Beta Student T tanh"]

relu_fits = {}
tanh_fits = {}

files = sorted(f for f in os.listdir(data_dir) if f.endswith(".npz"))
for fname in files:
    # Strip only the trailing extension (str.replace would also drop a
    # ".npz" occurring elsewhere in the name).
    base_config_name = os.path.splitext(fname)[0]  # e.g. "Friedman_N100_p10_sigma1.00_seed1"
    relu_fit = get_model_fits(
        config=base_config_name,
        results_dir=results_dir_relu,
        models=model_names_relu,
        include_prior=False,
    )
    tanh_fit = get_model_fits(
        config=base_config_name,
        results_dir=results_dir_tanh,
        models=model_names_tanh,
        include_prior=False,
    )

    relu_fits[base_config_name] = relu_fit
    tanh_fits[base_config_name] = tanh_fit
    


In [None]:
# Load posterior fits for every correlated-feature Friedman dataset found in
# `data_dir_correlated`, keyed by dataset basename without extension
# (e.g. "Friedman_N100_p10_sigma1.00_seed1").
data_dir_correlated = "datasets/friedman_correlated"
results_dir_relu_correlated = "results/regression/single_layer/relu/friedman_correlated"
results_dir_tanh_correlated = "results/regression/single_layer/tanh/friedman_correlated"

# The earlier 4-model runs used the first four names of each list.
model_names_relu_correlated = ["Gaussian", "Regularized Horseshoe", "Dirichlet Horseshoe", "Dirichlet Student T", "Beta Horseshoe", "Beta Student T"]
model_names_tanh_correlated = ["Gaussian tanh", "Regularized Horseshoe tanh", "Dirichlet Horseshoe tanh", "Dirichlet Student T tanh", "Beta Horseshoe tanh", "Beta Student T tanh"]

relu_fits_correlated = {}
tanh_fits_correlated = {}

files = sorted(f for f in os.listdir(data_dir_correlated) if f.endswith(".npz"))
for fname in files:
    # Strip only the trailing extension.
    base_config_name = os.path.splitext(fname)[0]  # e.g. "Friedman_N100_p10_sigma1.00_seed1"
    relu_fit_correlated = get_model_fits(
        config=base_config_name,
        results_dir=results_dir_relu_correlated,
        models=model_names_relu_correlated,
        include_prior=False,
    )
    tanh_fit_correlated = get_model_fits(
        config=base_config_name,
        results_dir=results_dir_tanh_correlated,
        models=model_names_tanh_correlated,
        include_prior=False,
    )

    relu_fits_correlated[base_config_name] = relu_fit_correlated
    tanh_fits_correlated[base_config_name] = tanh_fit_correlated
    


In [4]:
from utils.sparsity import forward_pass_relu, forward_pass_tanh, local_prune_weights

def compute_sparse_rmse_results(seeds, models, all_fits, get_N_sigma, forward_pass, folder,
                         sparsity=0.0, prune_fn=None):
    """Test-set RMSE of every posterior draw, optionally after pruning W_1.

    For each seed, loads
    ``datasets/{folder}/Friedman_N{N}_p10_sigma{sigma:.2f}_seed{seed}.npz``
    (``N``/``sigma`` from ``get_N_sigma(seed)``). For each model it then runs
    ``forward_pass`` on ``X_test`` once per posterior draw of
    ``all_fits[dataset_key][model]['posterior']``.

    Parameters
    ----------
    seeds : iterable
        Dataset seeds to evaluate.
    models : iterable of str
        Keys into ``all_fits[dataset_key]``.
    all_fits : dict
        ``all_fits[dataset_key][model]['posterior']`` must provide
        ``stan_variable(name)`` for "W_1", "W_L", "hidden_bias" and "output_bias".
    get_N_sigma : callable
        ``seed -> (N, sigma)``.
    forward_pass : callable
        ``(X, W1, b1, W2, b2) -> predictions``.
    folder : str
        Sub-directory of ``datasets/`` holding the ``.npz`` files.
    sparsity : float
        Pruning fraction handed to ``prune_fn``; no pruning when 0.
    prune_fn : callable or None
        ``([W1, W2], sparsity) -> masks``; only ``masks[0]`` (W1) is applied.

    Returns
    -------
    (pd.DataFrame, pd.DataFrame)
        One row per (seed, model, draw) with column ``rmse``. One row per
        (seed, model) with ``posterior_mean_rmse``, the RMSE of the
        draw-averaged prediction.

    Missing files and missing fits are reported with ``[SKIP]`` and skipped.
    """
    results = []
    posterior_means = []

    for seed in seeds:
        N, sigma = get_N_sigma(seed)
        dataset_key = f'Friedman_N{N}_p10_sigma{sigma:.2f}_seed{seed}'
        path = f"datasets/{folder}/{dataset_key}.npz"

        try:
            # The context manager closes the NpzFile handle once arrays are read.
            with np.load(path) as data:
                X_test = data["X_test"]
                # Flatten once so every RMSE compares 1-D with 1-D; an (n, 1)
                # target would otherwise broadcast against (n,) predictions to (n, n).
                y_test = np.asarray(data["y_test"]).reshape(-1)
        except FileNotFoundError:
            print(f"[SKIP] File not found: {path}")
            continue

        for model in models:
            try:
                fit = all_fits[dataset_key][model]['posterior']
                W1_samples = fit.stan_variable("W_1")           # (S, P, H)
                W2_samples = fit.stan_variable("W_L")           # (S, H, O)
                b1_samples = fit.stan_variable("hidden_bias")   # (S, O, H)
                b2_samples = fit.stan_variable("output_bias")   # (S, O)
            except KeyError:
                print(f"[SKIP] Model or posterior not found: {dataset_key} -> {model}")
                continue

            S = W1_samples.shape[0]
            rmses = np.zeros(S)
            y_hats = np.zeros((S, y_test.shape[0]))

            for i in range(S):
                W1 = W1_samples[i]
                W2 = W2_samples[i]

                if prune_fn is not None and sparsity > 0.0:
                    masks = prune_fn([W1, W2], sparsity)
                    # Only the first-layer mask is applied; W2 stays dense.
                    W1 = W1 * masks[0]

                y_hats[i] = np.asarray(
                    forward_pass(X_test, W1, b1_samples[i][0], W2, b2_samples[i])
                ).squeeze()
                rmses[i] = np.sqrt(np.mean((y_hats[i] - y_test) ** 2))

            posterior_mean = np.mean(y_hats, axis=0)
            posterior_mean_rmse = np.sqrt(np.mean((posterior_mean - y_test) ** 2))

            posterior_means.append({
                'seed': seed,
                'N': N,
                'sigma': sigma,
                'model': model,
                'sparsity': sparsity,
                'posterior_mean_rmse': posterior_mean_rmse
            })

            results.extend(
                {
                    'seed': seed,
                    'N': N,
                    'sigma': sigma,
                    'model': model,
                    'sparsity': sparsity,
                    'rmse': rmses[i]
                }
                for i in range(S)
            )

    df_rmse = pd.DataFrame(results)
    df_posterior_rmse = pd.DataFrame(posterior_means)

    return df_rmse, df_posterior_rmse


# Sparsity fractions passed to the pruning function in the sweep below.
sparsity_levels = [
    0.0, 0.1, 0.2, 0.3, 0.4, 0.5,
    0.6, 0.7, 0.8, 0.9, 0.95,
]

# Seeds of the evaluated datasets (N per seed: see get_N_sigma*).
seeds = [1, 2, 11]
seeds_correlated = [1, 6, 11]

def get_N_sigma(seed):
    """Return ``(N, sigma)`` for an uncorrelated-Friedman dataset seed.

    Seed 1 maps to N=100, seed 2 to N=200, and any other seed to N=500.
    sigma is always 1.0.
    """
    if seed == 1:
        return 100, 1.00
    if seed == 2:
        return 200, 1.00
    return 500, 1.00

def get_N_sigma_correlated(seed):
    """Return ``(N, sigma)`` for a correlated-Friedman dataset seed.

    Seed 1 maps to N=100, seed 6 to N=200, and any other seed to N=500.
    sigma is always 1.0.
    """
    if seed == 1:
        return 100, 1.00
    if seed == 6:
        return 200, 1.00
    return 500, 1.00

In [5]:
# Results per configuration, keyed by sparsity level:
# df_rmse_*[s] has one row per posterior draw; df_posterior_rmse_*[s] has one
# row per (seed, model) for the draw-averaged prediction.
df_rmse_relu, df_posterior_rmse_relu = {}, {}
df_rmse_relu_correlated, df_posterior_rmse_relu_correlated = {}, {}
df_rmse_tanh, df_posterior_rmse_tanh = {}, {}
df_rmse_tanh_correlated, df_posterior_rmse_tanh_correlated = {}, {}

# Each run: (per-draw store, posterior-mean store, seeds, models, fits,
#            seed -> (N, sigma), forward pass, dataset folder)
_sweep_runs = [
    (df_rmse_relu, df_posterior_rmse_relu,
     seeds, model_names_relu, relu_fits, get_N_sigma, forward_pass_relu, "friedman"),
    (df_rmse_relu_correlated, df_posterior_rmse_relu_correlated,
     seeds_correlated, model_names_relu_correlated, relu_fits_correlated,
     get_N_sigma_correlated, forward_pass_relu, "friedman_correlated"),
    (df_rmse_tanh, df_posterior_rmse_tanh,
     seeds, model_names_tanh, tanh_fits, get_N_sigma, forward_pass_tanh, "friedman"),
    (df_rmse_tanh_correlated, df_posterior_rmse_tanh_correlated,
     seeds_correlated, model_names_tanh_correlated, tanh_fits_correlated,
     get_N_sigma_correlated, forward_pass_tanh, "friedman_correlated"),
]

for sparsity in sparsity_levels:
    for per_draw, per_mean, run_seeds, run_models, run_fits, n_sigma_fn, fwd, data_folder in _sweep_runs:
        per_draw[sparsity], per_mean[sparsity] = compute_sparse_rmse_results(
            run_seeds, run_models, run_fits, n_sigma_fn, fwd, folder=data_folder,
            sparsity=sparsity, prune_fn=local_prune_weights,
        )

In [6]:
import pandas as pd

def _stack_over_sparsity(frames_by_sparsity):
    """Concatenate ``{sparsity: DataFrame}`` into one frame, tagging rows with their sparsity key."""
    return pd.concat(
        [frame.assign(sparsity=level) for level, frame in frames_by_sparsity.items()],
        ignore_index=True,
    )


df_rmse_full_relu = _stack_over_sparsity(df_rmse_relu)
df_rmse_full_relu_correlated = _stack_over_sparsity(df_rmse_relu_correlated)
df_rmse_full_tanh = _stack_over_sparsity(df_rmse_tanh)
df_rmse_full_tanh_correlated = _stack_over_sparsity(df_rmse_tanh_correlated)


In [7]:
# Drop " tanh" from tanh model names so both activations share model labels.
df_tanh_o = df_rmse_full_tanh.assign(
    model=df_rmse_full_tanh["model"].str.replace(" tanh", "", regex=False)
)
df_tanh_c = df_rmse_full_tanh_correlated.assign(
    model=df_rmse_full_tanh_correlated["model"].str.replace(" tanh", "", regex=False)
)

# ReLU model names carry no suffix; take independent copies.
df_relu_o = df_rmse_full_relu.copy()
df_relu_c = df_rmse_full_relu_correlated.copy()


In [14]:
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.lines import Line2D
import pandas as pd
from collections import OrderedDict

# Colour and short label for each prior that is plotted. The baselines
# ("Gaussian" -> "Gauss"/C0, "Regularized Horseshoe" -> "RHS"/C1) are not
# included, so they are not drawn.
_prior_styles = [
    ("Dirichlet Horseshoe", "C2", "DHS"),
    ("Dirichlet Student T", "C3", "DST"),
    ("Beta Horseshoe", "C4", "BHS"),
    ("Beta Student T", "C5", "BST"),
]
palette = {name: colour for name, colour, _ in _prior_styles}
abbr = {name: short for name, _, short in _prior_styles}

def make_merged_df_both(
    df_tanh_o, df_tanh_c, df_relu_o, df_relu_c,
    drop_tanh_suffix=True
):
    """Stack tanh and ReLU results for both data settings into one long frame.

    Each input gets two extra columns: ``activation`` ("tanh"/"ReLU") and
    ``setting`` ("Original"/"Correlated"). When ``drop_tanh_suffix`` is true,
    " tanh" is removed from tanh model names so both activations share labels.

    Rows are then restricted to models present under both activations. If
    that intersection is empty, every row is kept.
    """
    sources = [
        (df_tanh_o, "tanh", "Original"),
        (df_tanh_c, "tanh", "Correlated"),
        (df_relu_o, "ReLU", "Original"),
        (df_relu_c, "ReLU", "Correlated"),
    ]
    labelled = []
    for frame, activation, setting in sources:
        part = frame.copy()
        strip = activation == "tanh" and drop_tanh_suffix
        if strip and " tanh" in "".join(part["model"].unique()):
            part["model"] = part["model"].str.replace(" tanh", "", regex=False)
        part["activation"] = activation
        part["setting"] = setting
        labelled.append(part)
    merged = pd.concat(labelled, ignore_index=True)

    # Restrict to models both activations have, so the legend has no one-sided entries.
    tanh_models = set(merged.loc[merged["activation"] == "tanh", "model"].unique())
    relu_models = set(merged.loc[merged["activation"] == "ReLU", "model"].unique())
    shared = sorted(tanh_models & relu_models)
    if not shared:
        return merged
    return merged[merged["model"].isin(shared)]

df_all_both = make_merged_df_both(
    df_tanh_o=df_tanh_o, df_tanh_c=df_tanh_c,
    df_relu_o=df_relu_o, df_relu_c=df_relu_c,
)

def make_merged_df(
    df_tanh_o, df_tanh_c,
    drop_tanh_suffix=True
):
    """Stack the tanh results for both data settings into one long frame.

    All input columns are kept. Two columns are added: ``activation`` (always
    "tanh") and ``setting`` ("Original" for ``df_tanh_o``, "Correlated" for
    ``df_tanh_c``). When ``drop_tanh_suffix`` is true, " tanh" is removed from
    model names so they match the ReLU naming.
    """
    dfs = []
    for df, setting in [(df_tanh_o, "Original"), (df_tanh_c, "Correlated")]:
        d = df.copy()
        if drop_tanh_suffix and " tanh" in "".join(d["model"].unique()):
            d["model"] = d["model"].str.replace(" tanh", "", regex=False)
        d["activation"] = "tanh"
        d["setting"] = setting
        dfs.append(d)

    # Only one activation is present, so there is no cross-activation model
    # filter to apply (filtering by the set of its own models kept every row).
    return pd.concat(dfs, ignore_index=True)

df_all = make_merged_df(df_tanh_o=df_tanh_o, df_tanh_c=df_tanh_c)


In [17]:
def _plot_rmse_grid(df_all, Ns, figsize, save_path, activation_order=None):
    """Draw the 2 x len(Ns) sparsity-vs-RMSE grid shared by both public plotters.

    Rows are data settings ("Original", "Correlated"); columns are the values in
    ``Ns``. Colour encodes the prior and line style the activation. Only models
    with an entry in the module-level ``palette`` are drawn; panels with no such
    rows are hidden.

    When ``activation_order`` is given, seaborn's style mapping is pinned to
    that order and an "Activation" legend is added beside the "Prior" legend.
    The figure is saved to ``save_path`` (its directory is created if needed)
    and then shown.
    """
    setting_order = ["Original", "Correlated"]

    # Global styling; these settings persist for later figures as well.
    sns.set_style("whitegrid")
    plt.rcParams.update({
        "axes.spines.top": False,
        "axes.spines.right": False,
        "legend.frameon": True
    })

    # squeeze=False keeps `axes` 2-D even when len(Ns) == 1.
    fig, axes = plt.subplots(2, len(Ns), figsize=figsize, sharex=True, sharey=False, squeeze=False)

    for j, Nval in enumerate(Ns):
        for i, setting in enumerate(setting_order):
            ax = axes[i, j]
            dfN = df_all[(df_all["N"] == Nval) & (df_all["setting"] == setting)]
            # Keep only priors that have a colour, in palette order. Models such as
            # "Gaussian" / "Regularized Horseshoe" have no palette/abbr entry and
            # previously made abbr[m] raise KeyError while building hue_order.
            present = set(dfN["model"].unique())
            models_here = [m for m in palette if m in present]
            dfN = dfN[dfN["model"].isin(models_here)].copy()
            if dfN.empty:
                ax.set_visible(False)
                continue

            dfN["model_abbr"] = dfN["model"].map(lambda m: abbr.get(m, m))
            style_kwargs = {} if activation_order is None else {"style_order": activation_order}

            sns.lineplot(
                data=dfN,
                x="sparsity",
                y="rmse",
                hue="model_abbr",      # colour = prior (abbreviated)
                style="activation",    # line style / marker = activation
                markers=True,
                dashes=True,
                palette={abbr.get(m, m): palette[m] for m in models_here},
                hue_order=[abbr.get(m, m) for m in models_here],
                errorbar=None,
                ax=ax,
                **style_kwargs,
            )
            ax.set_title(f"N={Nval}", fontweight="normal", fontsize=15)
            ax.set_xlabel("Sparsity", fontsize=15)
            ax.set_ylabel("RMSE" if j == 0 else "", fontsize=15)
            ax.tick_params(axis='both', labelsize=10)
            ax.grid(True, which="major", alpha=0.25)
            if ax.legend_:  # per-panel legends are replaced by the figure legends below
                ax.legend_.remove()

    # Prior legend: one coloured entry per palette model present anywhere in df_all.
    models_present = [m for m in palette if (df_all["model"] == m).any()]
    prior_handles = [
        Line2D([0], [0], color=palette[m], marker='o', linestyle='-', linewidth=2, markersize=12)
        for m in models_present
    ]
    prior_labels = [abbr.get(m, m) for m in models_present]

    # Figure.legend records every call in fig.legends, so several figure legends
    # coexist without fig.add_artist (which would register the legend twice).
    if prior_handles:
        fig.legend(
            prior_handles, prior_labels,
            title="Prior",
            loc="upper right",
            ncol=len(prior_handles),
            frameon=True,
            bbox_to_anchor=(0.7, 1.02),
            fontsize=15
        )

    if activation_order is not None:
        # Approximates seaborn's default dash mapping: first style level solid, second dashed.
        act_style = {
            "tanh": dict(linestyle='-'),
            "ReLU": dict(linestyle='--'),
        }
        shown = [act for act in activation_order if (df_all["activation"] == act).any()]
        activation_handles = [
            Line2D([0], [0], color='black', linewidth=2, markersize=12, **act_style[act])
            for act in shown
        ]
        if activation_handles:
            fig.legend(
                activation_handles, shown,
                title="Activation",
                loc="upper left",
                ncol=len(activation_handles),
                frameon=True,
                bbox_to_anchor=(0.7, 1.02),
                fontsize=15
            )

    fig.tight_layout(rect=[0, 0.4, 0.95, 0.95])
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    fig.savefig(save_path, bbox_inches="tight")
    plt.show()


def plot_rmse_one_figure_both(
    df_all,
    Ns=(100, 200, 500), figsize=(12, 12), title="Original vs Correlated (tanh vs ReLU)",
    save_path="figures_for_use_in_paper/friedman_sparsity_with_beta.pdf",
):
    """Plot RMSE vs sparsity for tanh and ReLU runs on an (Original/Correlated) x N grid.

    ``df_all`` needs columns ``N``, ``setting``, ``model``, ``activation``,
    ``sparsity`` and ``rmse`` (e.g. the output of ``make_merged_df_both``).
    Priors are coloured; activations use different line styles. The figure is
    written to ``save_path``.

    ``title`` is accepted for backward compatibility but is currently unused
    (the suptitle is disabled).
    """
    _plot_rmse_grid(df_all, Ns, figsize, save_path, activation_order=["tanh", "ReLU"])


def plot_rmse_one_figure(
    df_all,
    Ns=(100, 200, 500), figsize=(12, 12), title="Original vs Correlated (tanh vs ReLU)",
    save_path="figures_for_use_in_paper/friedman_sparsity_tanh_with_beta.pdf",
):
    """Plot RMSE vs sparsity for a single activation on an (Original/Correlated) x N grid.

    ``df_all`` needs columns ``N``, ``setting``, ``model``, ``activation``,
    ``sparsity`` and ``rmse`` (e.g. the output of ``make_merged_df``). Only the
    "Prior" legend is drawn. The figure is written to ``save_path``.

    ``title`` is accepted for backward compatibility but is currently unused.
    """
    _plot_rmse_grid(df_all, Ns, figsize, save_path, activation_order=None)


In [None]:
plot_rmse_one_figure(df_all, Ns=(100, 200, 500), title="Original vs Correlated")


## Alternative pruning scheme: global masks from posterior weight importance

In [15]:
import numpy as np

def build_global_mask_from_posterior(
    W_samples,
    sparsity,
    method="Eabs",          # "Eabs" or "Eabs_stability"
    stability_quantile=0.1, # used if method="Eabs_stability"
    prune_smallest=True
):
    """Build one pruning mask for a weight matrix from all of its posterior draws.

    Each weight gets an importance score computed across draws:

    * ``"Eabs"``: ``a = E|w|``, the mean absolute value over draws.
    * ``"Eabs_stability"``: ``a * pi`` with ``pi = P(|w| > t)``, where ``t`` is
      the ``stability_quantile`` quantile of all ``|w|`` values pooled over
      draws and entries.

    Exactly ``floor(sparsity * n_weights)`` entries are pruned: the lowest
    scores when ``prune_smallest`` is true, otherwise the highest. Ties at the
    cut-off are broken by flat (C-order) position; among tied entries, those
    with the larger flat index are pruned first.

    Parameters
    ----------
    W_samples : np.ndarray
        Posterior draws with the draw axis first, shape ``(S, ...)``.
    sparsity : float
        Fraction of entries to prune, ``0 <= sparsity < 1``.
    method : str
        ``"Eabs"`` or ``"Eabs_stability"``.
    stability_quantile : float
        Quantile defining ``t`` for ``"Eabs_stability"``.
    prune_smallest : bool
        Prune the lowest-scoring entries (default) or the highest-scoring ones.

    Returns
    -------
    np.ndarray
        Float mask of shape ``W_samples.shape[1:]``; 1.0 = keep, 0.0 = prune.

    Raises
    ------
    AssertionError
        If ``sparsity`` is outside ``[0, 1)``.
    ValueError
        If ``method`` is not recognised.
    """
    assert 0.0 <= sparsity < 1.0
    S = W_samples.shape[0]
    W_abs = np.abs(W_samples)  # (S, ...)

    # Importance score a = E|w|
    a = W_abs.mean(axis=0)

    if method == "Eabs":
        score = a
    elif method == "Eabs_stability":
        # Stability proxy pi = P(|w| > t), with t a small global quantile of |w|.
        t = np.quantile(W_abs.reshape(S, -1), stability_quantile)
        pi = (W_abs > t).mean(axis=0)
        # Favour weights that are both large on average and consistently non-tiny.
        score = a * pi
    else:
        raise ValueError("method must be 'Eabs' or 'Eabs_stability'")

    num_params = score.size
    # Small tolerance so float error (e.g. 0.57 * 100 == 56.999...) does not
    # prune one entry fewer than intended.
    k_prune = int(np.floor(sparsity * num_params + 1e-9))
    if k_prune == 0:
        return np.ones_like(score, dtype=float)

    flat = score.reshape(-1)
    position = np.arange(num_params)
    primary = flat if prune_smallest else -flat
    # np.lexsort sorts by the LAST key first: primary score order, then
    # descending flat position, so later entries are pruned first on ties.
    order = np.lexsort((-position, primary))

    mask_flat = np.ones(num_params, dtype=float)
    mask_flat[order[:k_prune]] = 0.0
    return mask_flat.reshape(score.shape)


def precompute_global_masks(
    all_fits,
    dataset_key,
    model,
    sparsity_levels,
    prune_W2=False,
    method="Eabs_stability"
):
    """Build global pruning masks for one fitted model at every sparsity level.

    Reads the "W_1" and "W_L" posterior draws from
    ``all_fits[dataset_key][model]["posterior"]``. It calls
    ``build_global_mask_from_posterior`` once per level for W_1, and once more
    per level for W_L when ``prune_W2`` is true.

    Returns
    -------
    dict
        ``{sparsity: (mask_W1, mask_W2)}``; ``mask_W2`` is None unless ``prune_W2``.

    Raises
    ------
    KeyError
        If ``dataset_key`` or ``model`` is absent from ``all_fits``.
    """
    posterior = all_fits[dataset_key][model]["posterior"]
    first_layer_draws = posterior.stan_variable("W_1")   # (S, P, H)
    output_layer_draws = posterior.stan_variable("W_L")  # (S, H, O) or (S, H) depending on O

    def _masks_at(level):
        w1_mask = build_global_mask_from_posterior(first_layer_draws, level, method=method)
        w2_mask = (
            build_global_mask_from_posterior(output_layer_draws, level, method=method)
            if prune_W2 else None
        )
        return w1_mask, w2_mask

    return {level: _masks_at(level) for level in sparsity_levels}


In [16]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from utils.generate_data import sample_gaussian_copula_uniform


def _friedman_response(X):
    """Noise-free Friedman #1 target computed from the first five columns of ``X``."""
    x0, x1, x2, x3, x4 = X[:, 0], X[:, 1], X[:, 2], X[:, 3], X[:, 4]
    return (
        10 * np.sin(np.pi * x0 * x1) +
        20 * (x2 - 0.5) ** 2 +
        10 * x3 +
        5.0 * x4
    )


def _split_and_standardize(X, y, test_size, seed, standardize_y, return_scale):
    """Train/test split, then optionally standardise y using training-split statistics.

    Returns a 4-tuple ``(X_train, X_test, y_train, y_test)`` when
    ``return_scale`` is false. Otherwise returns a 6-tuple ending in
    ``(y_mean, y_std)``, which is ``(0.0, 1.0)`` when ``standardize_y`` is false.
    """
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=seed)

    if not standardize_y:
        if return_scale:
            return X_train, X_test, y_train, y_test, 0.0, 1.0
        return X_train, X_test, y_train, y_test

    y_mean = y_train.mean()
    # Guard against a constant training target.
    y_std = y_train.std() if y_train.std() > 0 else 1.0

    y_train_s = (y_train - y_mean) / y_std
    y_test_s = (y_test - y_mean) / y_std

    if return_scale:
        return X_train, X_test, y_train_s, y_test_s, y_mean, y_std
    return X_train, X_test, y_train_s, y_test_s


def generate_Friedman_data_v2(N=200, D=10, sigma=1.0, test_size=0.2, seed=42, standardize_y=True, return_scale=True):
    """Friedman #1 regression data with i.i.d. Uniform(0, 1) features.

    Only the first five of the ``D`` columns enter the target; the rest are
    pure noise features. Gaussian noise with std ``sigma`` is added. This seeds
    the global NumPy RNG with ``seed``.

    Returns
    -------
    tuple
        ``(X_train, X_test, y_train, y_test, y_mean, y_std)`` when ``return_scale``.
        y is standardised with training statistics when ``standardize_y``;
        otherwise ``y_mean=0.0`` and ``y_std=1.0``. Without ``return_scale``,
        only the first four elements are returned.
    """
    np.random.seed(seed)
    X = np.random.uniform(0, 1, size=(N, D))
    y = _friedman_response(X) + np.random.normal(0, sigma, size=N)
    return _split_and_standardize(X, y, test_size, seed, standardize_y, return_scale)


def generate_correlated_Friedman_data_v2(N=100, D=10, sigma=1.0, test_size=0.2, seed=42, standardize_y=True, return_scale=True):
    """Friedman #1 regression data with copula-correlated Uniform(0, 1) features.

    A fixed pool of 10,000 rows is drawn with
    ``sample_gaussian_copula_uniform(..., random_state=123)``; the pool is the
    same for every ``seed``. ``seed`` controls which ``N`` rows are subsampled
    (when ``N != 10000``), the noise, and the train/test split.

    The target correlations are hard-coded for 10 features (the original notes
    call them Spearman correlations):

    * features 0-2: pairwise 0.8
    * features 5-9: pairwise -0.5
    * any feature in 0-4 vs any in 5-9: 0.15, except (0, 9) = 0.4,
      (2, 7) = 0.9 and (1, 6) = -0.9
    * (3, 4) = -0.9

    Some strong pairs may be jointly infeasible; the original note says these
    are projected, presumably inside ``sample_gaussian_copula_uniform`` (verify).

    Parameters
    ----------
    N : int
        Number of samples; at most 10,000 because rows are drawn without replacement.
    D : int
        Must be 10; the correlation structure is defined for exactly 10 features.
    sigma : float
        Noise standard deviation.
    test_size : float or int
        Passed to ``train_test_split``.
    seed : int
        Seeds the global NumPy RNG and the split.
    standardize_y, return_scale : bool
        Same meaning as in ``generate_Friedman_data_v2``.

    Raises
    ------
    ValueError
        If ``D != 10`` (previously ``D`` was silently ignored).
    """
    if D != 10:
        raise ValueError(f"generate_correlated_Friedman_data_v2 only supports D=10 (got D={D})")

    np.random.seed(seed)
    d = 10
    S_custom = np.eye(d)
    # Features 0..2: strong positive correlation.
    for i in range(0, 3):
        for j in range(i + 1, 3):
            S_custom[i, j] = S_custom[j, i] = 0.8
    # Features 5..9: negative correlation.
    for i in range(5, 10):
        for j in range(i + 1, 10):
            S_custom[i, j] = S_custom[j, i] = -0.5
    # Weak cross-block correlation between {0..4} and {5..9}.
    for i in range(0, 5):
        for j in range(5, 10):
            S_custom[i, j] = S_custom[j, i] = 0.15
    # Bespoke pairs; the +/-0.9 entries may be infeasible jointly.
    S_custom[0, 9] = S_custom[9, 0] = 0.4
    S_custom[2, 7] = S_custom[7, 2] = 0.9
    S_custom[3, 4] = S_custom[4, 3] = -0.9
    S_custom[1, 6] = S_custom[6, 1] = -0.9

    U, _ = sample_gaussian_copula_uniform(n=10000, S=S_custom, random_state=123)
    if N != U.shape[0]:
        idx = np.random.choice(U.shape[0], size=N, replace=False)
        X = U[idx, :]
    else:
        X = U

    y = _friedman_response(X) + np.random.normal(0, sigma, size=N)
    return _split_and_standardize(X, y, test_size, seed, standardize_y, return_scale)

def make_large_eval_set(
    generator_fn,
    N_train,
    D,
    sigma,
    seed,
    n_eval=5000,
    standardize_y=True
):
    """Simulate a large held-out evaluation set with ``generator_fn``.

    Calls ``generator_fn`` with ``N = N_train + n_eval`` samples and
    ``test_size=n_eval``, an absolute row count (``generator_fn`` must accept
    an integer ``test_size``, as scikit-learn's ``train_test_split`` does).
    Only the test part is returned.

    Returns
    -------
    X_eval, y_eval, y_mean, y_std
        ``y_eval`` is squeezed. ``y_mean``/``y_std`` are the standardisation
        statistics the generator computed from its own ``N_train``-row
        training split.

    NOTE(review): this is a fresh simulation, not the stored training data, so
    ``y_mean``/``y_std`` only approximate the statistics used when the model
    was fitted. For the correlated generator, rows come from a fixed
    10,000-row pool. If the stored training sets were drawn from that same
    pool, evaluation rows can coincide with training rows — confirm.
    """
    N_total = N_train + n_eval

    # An int test_size is an exact count; the former float ratio
    # n_eval / N_total is rounded up by train_test_split and can give
    # n_eval + 1 rows through floating-point error.
    _, X_eval, _, y_eval, y_mean, y_std = generator_fn(
        N=N_total, D=D, sigma=sigma, test_size=n_eval, seed=seed,
        standardize_y=standardize_y, return_scale=True
    )
    return X_eval, np.asarray(y_eval).squeeze(), y_mean, y_std


def _logsumexp(a, axis=None):
    amax = np.max(a, axis=axis, keepdims=True)
    out = amax + np.log(np.sum(np.exp(a - amax), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)

def gaussian_nll_pointwise(y, mu, sigma):
    """Elementwise negative log-density of ``y`` under ``Normal(mu, sigma**2)``.

    Arguments broadcast against each other; the result has the broadcast shape.
    """
    variance = sigma**2
    log_normaliser = 0.5*np.log(2*np.pi*variance)
    return log_normaliser + 0.5*((y-mu)**2)/variance

def compute_sparse_metrics_results_globalmask_large_eval(
    seeds, models, all_fits, get_N_sigma, forward_pass,
    folder,
    sparsity=0.0,
    masks_cache=None,
    prune_W2=False,
    compute_nll=True,
    noise_var_name="sigma",
    n_eval=5000,
    D=10,
    standardize_y=True,
    # pass the correct generator functions
    gen_uncorr=None,
    gen_corr=None,
):
    """RMSE / NLL of globally-pruned posterior draws on a large simulated test set.

    Instead of the small stored ``X_test``/``y_test``, ``n_eval`` points are
    simulated per seed via ``make_large_eval_set``. ``gen_corr`` is used when
    ``"friedman_correlated"`` occurs in ``folder``, otherwise ``gen_uncorr``.
    Metrics are on the standardised scale when ``standardize_y``; the
    ``*_orig`` columns rescale by the evaluation set's ``y_std``.

    Parameters
    ----------
    seeds, models, all_fits, get_N_sigma, forward_pass
        As in ``compute_sparse_rmse_results``.
    folder : str
        Dataset folder name; only used to choose the generator.
    sparsity : float
        Sparsity level whose masks are looked up; 0 means no pruning.
    masks_cache : dict or None
        ``masks_cache[(dataset_key, model)][sparsity] -> (mask_W1, mask_W2)``,
        e.g. from ``build_masks_cache_for_all``; used only when ``sparsity > 0``.
    prune_W2 : bool
        Also apply ``mask_W2`` when it is present.
    compute_nll : bool
        Add Gaussian NLL metrics using the posterior draws of
        ``noise_var_name``. If that variable cannot be read, a warning is
        printed and unit noise (sigma = 1) is used.
    n_eval, D, standardize_y
        Forwarded to ``make_large_eval_set``.
    gen_uncorr, gen_corr : callable
        Data generators (both required).

    Returns
    -------
    (pd.DataFrame, pd.DataFrame)
        Per-draw rows: ``rmse``, ``rmse_orig`` and, if ``compute_nll``, ``nll``.
        Per-(seed, model) rows: ``posterior_mean_rmse``,
        ``posterior_mean_rmse_orig`` and, if ``compute_nll``, ``expected_nll``,
        ``predictive_nll`` and ``predictive_nll_orig``.
    """
    assert gen_uncorr is not None and gen_corr is not None, "Pass both generator functions."

    results = []
    posterior_means = []

    # The generator depends only on `folder`, so choose it once.
    gen_fn = gen_corr if "friedman_correlated" in folder else gen_uncorr

    for seed in seeds:
        N, sigma = get_N_sigma(seed)
        dataset_key = f'Friedman_N{N}_p10_sigma{sigma:.2f}_seed{seed}'

        X_test, y_test, y_mean, y_std = make_large_eval_set(
            generator_fn=gen_fn,
            N_train=N,
            D=D,
            sigma=sigma,
            seed=seed,
            n_eval=n_eval,
            standardize_y=standardize_y
        )
        n_test = y_test.shape[0]

        for model in models:
            try:
                fit = all_fits[dataset_key][model]['posterior']
                W1_samples = fit.stan_variable("W_1")           # (S, P, H)
                W2_samples = fit.stan_variable("W_L")           # (S, H, O)
                b1_samples = fit.stan_variable("hidden_bias")   # (S, O, H)
                b2_samples = fit.stan_variable("output_bias")   # (S, O)
            except KeyError:
                print(f"[SKIP] Model or posterior not found: {dataset_key} -> {model}")
                continue

            noise_samples = None
            if compute_nll:
                try:
                    noise_samples = fit.stan_variable(noise_var_name).squeeze()
                except Exception as exc:
                    # Best effort: fall back to unit noise, but report it so the
                    # NLL columns are not silently computed with sigma = 1.
                    print(f"[WARN] Could not read '{noise_var_name}' for {dataset_key} -> {model} "
                          f"({exc!r}); using sigma = 1 for NLL.")

            S = W1_samples.shape[0]
            y_hats = np.zeros((S, n_test))
            rmses = np.zeros(S)

            mask_W1 = mask_W2 = None
            if masks_cache is not None and sparsity > 0.0:
                mask_W1, mask_W2 = masks_cache[(dataset_key, model)][sparsity]

            for i in range(S):
                W1 = W1_samples[i]
                W2 = W2_samples[i]

                if mask_W1 is not None:
                    W1 = W1 * mask_W1
                if prune_W2 and (mask_W2 is not None):
                    W2 = W2 * mask_W2

                y_hat = forward_pass(X_test, W1, b1_samples[i][0], W2, b2_samples[i]).squeeze()
                y_hats[i] = y_hat
                rmses[i] = np.sqrt(np.mean((y_hat - y_test)**2))

            # Posterior-mean RMSE (standardised scale).
            posterior_mean = y_hats.mean(axis=0)
            posterior_mean_rmse = np.sqrt(np.mean((posterior_mean - y_test)**2))

            out_pm = {
                'seed': seed,
                'N': N,
                'sigma': sigma,
                'model': model,
                'sparsity': sparsity,
                'n_eval': n_test,
                'posterior_mean_rmse': posterior_mean_rmse,
                'posterior_mean_rmse_orig': posterior_mean_rmse * y_std,  # back to original y scale
            }

            nll_draws = None
            if compute_nll:
                if noise_samples is None:
                    sig_s = np.ones(S)
                else:
                    sig_s = np.asarray(noise_samples).reshape(-1)[:S]

                # Pointwise NLL for every (draw, test point), computed once and
                # reused for all NLL summaries below.
                nll_pointwise = np.stack([
                    gaussian_nll_pointwise(y_test, y_hats[i], sig_s[i])
                    for i in range(S)
                ], axis=0)  # (S, n_eval)
                nll_draws = nll_pointwise.mean(axis=1)

                # Expected NLL: average over draws of each draw's mean NLL.
                out_pm["expected_nll"] = nll_draws.mean()

                # Predictive NLL of the equal-weight mixture over draws (negative lppd).
                lppd = (_logsumexp(-nll_pointwise, axis=0) - np.log(S)).mean()
                predictive_nll = -lppd
                out_pm["predictive_nll"] = predictive_nll

                # Change of variables y_orig = y_mean + y_std * y adds log(y_std)
                # to each pointwise NLL. Assumes the noise posterior is on the
                # standardised scale (i.e. the model was fitted on standardised y).
                out_pm["predictive_nll_orig"] = predictive_nll + np.log(y_std)

            posterior_means.append(out_pm)

            for i in range(S):
                row = {
                    'seed': seed,
                    'N': N,
                    'sigma': sigma,
                    'model': model,
                    'sparsity': sparsity,
                    'n_eval': n_test,
                    'rmse': rmses[i],
                    'rmse_orig': rmses[i] * y_std
                }
                if compute_nll:
                    row["nll"] = nll_draws[i]
                results.append(row)

    return pd.DataFrame(results), pd.DataFrame(posterior_means)


In [17]:
from utils.sparsity import forward_pass_relu, forward_pass_tanh, local_prune_weights

def build_masks_cache_for_all(
    all_fits,
    dataset_keys,
    models,
    sparsity_levels,
    prune_W2=False,
    method="Eabs_stability"
):
    """Precompute global pruning masks for every (dataset_key, model) pair.

    ``precompute_global_masks`` is called once per pair (datasets in the outer
    loop, models in the inner loop) and its result is stored under the tuple key
    ``(dataset_key, model)``. A pair whose call raises ``KeyError`` is reported
    with a ``[SKIP MASKS]`` line and left out of the result. Note that any
    KeyError raised inside ``precompute_global_masks`` is treated this way.

    Returns:
        dict mapping ``(dataset_key, model)`` to the value returned by
        ``precompute_global_masks``.
    """
    cache = {}
    pairs = ((key, name) for key in dataset_keys for name in models)
    for key, name in pairs:
        try:
            cache[(key, name)] = precompute_global_masks(
                all_fits=all_fits,
                dataset_key=key,
                model=name,
                sparsity_levels=sparsity_levels,
                prune_W2=prune_W2,
                method=method,
            )
        except KeyError:
            print(f"[SKIP MASKS] Missing fit for {key} -> {name}")
    return cache

# Seeds evaluated in the original and correlated Friedman settings; each seed is
# mapped to a training-set size by get_N_sigma / get_N_sigma_correlated below.
seeds = [1, 2, 11]
seeds_correlated = [1, 6, 11]

# Sparsity grid 0.0, 0.1, ..., 0.7 (k / 10 is correctly rounded, so these equal the literals).
sparsity_levels = [step / 10 for step in range(8)]


def get_N_sigma(seed):
    """Return ``(N, sigma)`` for a seed of the original Friedman setting.

    Seed 1 -> N=100, seed 2 -> N=200, any other seed -> N=500; sigma is always 1.00.
    """
    n_obs = 100 if seed == 1 else (200 if seed == 2 else 500)
    return n_obs, 1.00

def get_N_sigma_correlated(seed):
    """Return ``(N, sigma)`` for a seed of the correlated Friedman setting.

    Seed 1 -> N=100, seed 6 -> N=200, any other seed -> N=500; sigma is always 1.00.
    """
    n_obs = 100 if seed == 1 else (200 if seed == 6 else 500)
    return n_obs, 1.00

# Build the list of dataset keys you actually evaluate (same keys as in your compute loop).
# The format must match the keys under which the fits were stored.
dataset_keys = []
for seed in seeds:
    N, sigma = get_N_sigma(seed)
    dataset_keys.append(f'Friedman_N{N}_p10_sigma{sigma:.2f}_seed{seed}')

dataset_keys_corr = []
for seed in seeds_correlated:
    N, sigma = get_N_sigma_correlated(seed)
    dataset_keys_corr.append(f'Friedman_N{N}_p10_sigma{sigma:.2f}_seed{seed}')


# Precompute masks once; each cache is keyed by (dataset_key, model).
masks_relu = build_masks_cache_for_all(relu_fits, dataset_keys, model_names_relu, sparsity_levels, prune_W2=False)
masks_relu_corr = build_masks_cache_for_all(relu_fits_correlated, dataset_keys_corr, model_names_relu_correlated, sparsity_levels, prune_W2=False)

masks_tanh = build_masks_cache_for_all(tanh_fits, dataset_keys, model_names_tanh, sparsity_levels, prune_W2=False)
masks_tanh_corr = build_masks_cache_for_all(tanh_fits_correlated, dataset_keys_corr, model_names_tanh_correlated, sparsity_levels, prune_W2=False)


# Then run evaluation across sparsity grid, now using GLOBAL masks
# df_rmse_relu, df_post_relu = {}, {}
# for q in sparsity_levels:
#     df_rmse_relu[q], df_post_relu[q] = compute_sparse_metrics_results_globalmask(
#         seeds, model_names_relu, relu_fits, get_N_sigma, forward_pass_relu, folder="friedman",
#         sparsity=q, masks_cache=masks_relu, prune_W2=False, compute_nll=True, noise_var_name="noise"
#     )

# df_rmse_relu_corr, df_post_relu_corr = {}, {}
# for q in sparsity_levels:
#     df_rmse_relu_corr[q], df_post_relu_corr[q] = compute_sparse_metrics_results_globalmask(
#         seeds_correlated, model_names_relu_correlated, relu_fits_correlated, get_N_sigma_correlated, forward_pass_relu, folder="friedman_correlated",
#         sparsity=q, masks_cache=masks_relu_corr, prune_W2=False, compute_nll=True, noise_var_name="noise"
#     )

# Evaluation settings shared by the original and correlated tanh runs below.
tanh_eval_common = dict(
    forward_pass=forward_pass_tanh,
    prune_W2=False,
    compute_nll=True,
    noise_var_name="sigma",
    n_eval=5000,
    D=10,
    standardize_y=True,
    gen_uncorr=generate_Friedman_data_v2,
    gen_corr=generate_correlated_Friedman_data_v2,
)

# tanh, original (uncorrelated) setting.
df_rmse_tanh, df_post_tanh = {}, {}
for q in sparsity_levels:
    df_rmse_tanh[q], df_post_tanh[q] = compute_sparse_metrics_results_globalmask_large_eval(
        seeds=seeds,
        models=model_names_tanh,
        all_fits=tanh_fits,
        get_N_sigma=get_N_sigma,
        folder="friedman",
        sparsity=q,
        masks_cache=masks_tanh,
        **tanh_eval_common,
    )

# tanh, correlated setting. Evaluate the same model list that masks_tanh_corr was
# built from (model_names_tanh_correlated), so the models evaluated and the
# (dataset_key, model) entries in the mask cache agree.
df_rmse_tanh_corr, df_post_tanh_corr = {}, {}
for q in sparsity_levels:
    df_rmse_tanh_corr[q], df_post_tanh_corr[q] = compute_sparse_metrics_results_globalmask_large_eval(
        seeds=seeds_correlated,
        models=model_names_tanh_correlated,
        all_fits=tanh_fits_correlated,
        get_N_sigma=get_N_sigma_correlated,
        folder="friedman_correlated",
        sparsity=q,
        masks_cache=masks_tanh_corr,
        **tanh_eval_common,
    )


In [18]:
import pandas as pd

# df_post_relu_full = pd.concat(
#     [df.assign(sparsity=sparsity) for sparsity, df in df_post_relu.items()],
#     ignore_index=True
# )

# df_post_relu_corr_full = pd.concat(
#     [df.assign(sparsity=sparsity) for sparsity, df in df_post_relu_corr.items()],
#     ignore_index=True
# )

def _stack_by_sparsity(frames_by_sparsity):
    """Concatenate a ``{sparsity: DataFrame}`` dict into one frame with a ``sparsity`` column."""
    tagged = [frame.assign(sparsity=level) for level, frame in frames_by_sparsity.items()]
    return pd.concat(tagged, ignore_index=True)


df_post_tanh_full = _stack_by_sparsity(df_post_tanh)
df_post_tanh_corr_full = _stack_by_sparsity(df_post_tanh_corr)

In [19]:
# Drop the " tanh" suffix from model names so they match the plain prior names
# used by `palette` / `abbr`. assign() returns a new frame; inputs are untouched.
df_tanh_o = df_post_tanh_full.assign(
    model=df_post_tanh_full["model"].str.replace(" tanh", "", regex=False)
)
df_tanh_c = df_post_tanh_corr_full.assign(
    model=df_post_tanh_corr_full["model"].str.replace(" tanh", "", regex=False)
)


In [20]:
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.lines import Line2D
import pandas as pd
from collections import OrderedDict

# --- Per-prior colour and legend abbreviation, kept in one table ---
_PRIOR_STYLE = [
    # (model name,            abbreviation, matplotlib colour)
    ("Gaussian",              "Gauss",      "C0"),
    ("Regularized Horseshoe", "RHS",        "C1"),
    ("Dirichlet Horseshoe",   "DHS",        "C2"),
    ("Dirichlet Student T",   "DST",        "C3"),
    ("Beta Horseshoe",        "BHS",        "C4"),
    ("Beta Student T",        "BST",        "C5"),
]
palette = {name: colour for name, _, colour in _PRIOR_STYLE}
abbr = {name: short for name, short, _ in _PRIOR_STYLE}

def make_merged_df_both(
    df_tanh_o, df_tanh_c, df_relu_o, df_relu_c,
    drop_tanh_suffix=True
):
    """Stack tanh and ReLU results for both data settings into one long DataFrame.

    Each input frame is copied, so the inputs are not modified. Each copy is
    tagged with an ``activation`` column ("tanh" / "ReLU") and a ``setting``
    column ("Original" / "Correlated"). All original columns are kept.

    The frames are concatenated in the order tanh-Original, tanh-Correlated,
    ReLU-Original, ReLU-Correlated, with a fresh 0..n-1 index. If at least one
    model appears under both activations, rows are then restricted to those
    models so legends do not show priors that exist for only one activation.
    The filtered frame keeps the concatenated index; it is not reset.

    Args:
        df_tanh_o, df_tanh_c: tanh results for the original / correlated setting.
        df_relu_o, df_relu_c: ReLU results for the original / correlated setting.
        drop_tanh_suffix: strip the literal " tanh" from tanh model names so
            they match the ReLU names.

    Returns:
        The merged (and possibly filtered) DataFrame.
    """
    tagged = []
    for frame, activation, setting in [
        (df_tanh_o, "tanh", "Original"),
        (df_tanh_c, "tanh", "Correlated"),
        (df_relu_o, "ReLU", "Original"),
        (df_relu_c, "ReLU", "Correlated"),
    ]:
        d = frame.copy()
        # Per-value check via astype(str) instead of "".join(...): the join raised
        # on non-string names (e.g. NaN) and could match " tanh" across the
        # boundary of two concatenated names.
        if (
            activation == "tanh"
            and drop_tanh_suffix
            and d["model"].astype(str).str.contains(" tanh", regex=False).any()
        ):
            d["model"] = d["model"].str.replace(" tanh", "", regex=False)
        d["activation"] = activation
        d["setting"] = setting
        tagged.append(d)
    out = pd.concat(tagged, ignore_index=True)

    # Keep only models that exist in BOTH activations so the legend has no ghosts.
    # isin() accepts the set directly; sorting it first was unnecessary and raised
    # on mixed string/NaN model names.
    models_tanh = set(out.loc[out["activation"] == "tanh", "model"].unique())
    models_relu = set(out.loc[out["activation"] == "ReLU", "model"].unique())
    common_models = models_tanh & models_relu
    if common_models:
        out = out[out["model"].isin(common_models)]
    return out

# df_all_both = make_merged_df_both(df_tanh_o, df_tanh_c, df_relu_o, df_relu_c)

def make_merged_df(
    df_tanh_o, df_tanh_c,
    drop_tanh_suffix=True
):
    """Stack tanh results for the original and correlated settings into one long DataFrame.

    Each input frame is copied, so the inputs are not modified. Each copy is
    tagged with ``activation="tanh"`` and a ``setting`` column ("Original" /
    "Correlated"). All original columns are kept. Rows are concatenated
    original-first with a fresh 0..n-1 index.

    Args:
        df_tanh_o, df_tanh_c: tanh results for the original / correlated setting.
        drop_tanh_suffix: strip the literal " tanh" from model names.

    Returns:
        The merged DataFrame.
    """
    tagged = []
    for frame, setting in [(df_tanh_o, "Original"), (df_tanh_c, "Correlated")]:
        d = frame.copy()
        # Per-value check via astype(str) instead of "".join(...), which raised on
        # non-string (e.g. NaN) model names.
        if drop_tanh_suffix and d["model"].astype(str).str.contains(" tanh", regex=False).any():
            d["model"] = d["model"].str.replace(" tanh", "", regex=False)
        d["activation"] = "tanh"
        d["setting"] = setting
        tagged.append(d)

    # Only one activation is present here, so (unlike make_merged_df_both) there
    # is no cross-activation model filtering to do.
    return pd.concat(tagged, ignore_index=True)

df_all = make_merged_df(df_tanh_o, df_tanh_c, drop_tanh_suffix=True)  # tanh results, both settings


In [21]:
def plot_rmse_one_figure(
    df_all,
    Ns=(100, 200, 500), figsize=(12, 12), title="Original vs Correlated (tanh vs ReLU)"
):
    """Plot posterior-mean RMSE against sparsity on a 2 x len(Ns) grid.

    Rows are the data settings ("Original", "Correlated"); columns are the
    training-set sizes in ``Ns``. In each panel, colour encodes the prior,
    labelled with the module-level ``abbr`` mapping (falling back to the full
    name). Line style and marker encode the ``activation`` column. Panels with
    no data are hidden. A single figure-level legend lists the priors present
    in ``df_all``, ordered as in the module-level ``palette``.

    Args:
        df_all: long DataFrame with at least the columns ``N``, ``setting``,
            ``model``, ``activation``, ``sparsity`` and ``posterior_mean_rmse``.
        Ns: training-set sizes, one subplot column each.
        figsize: matplotlib figure size.
        title: accepted for compatibility; no figure title is drawn.

    Side effects:
        Sets the seaborn style, updates the global ``plt.rcParams`` and calls
        ``plt.show()``.
    """
    setting_order = ["Original", "Correlated"]
    # Canonical prior order, shared by the per-panel hue order and the global legend.
    prior_order = list(palette)

    def short_name(model):
        # Abbreviation when one is defined, otherwise the full model name.
        return abbr.get(model, model)

    def prior_rank(model):
        # Known priors keep their palette position; unknown ones go last.
        return prior_order.index(model) if model in palette else len(prior_order)

    sns.set_style("whitegrid")
    plt.rcParams.update({
        "axes.spines.top": False,
        "axes.spines.right": False,
        "legend.frameon": True,
    })

    # squeeze=False always yields a 2-D axes array, including when len(Ns) == 1.
    fig, axes = plt.subplots(
        2, len(Ns), figsize=figsize, sharex=True, sharey=False, squeeze=False
    )

    for j, Nval in enumerate(Ns):
        for i, setting in enumerate(setting_order):
            ax = axes[i, j]
            dfN = df_all[(df_all["N"] == Nval) & (df_all["setting"] == setting)].copy()
            if dfN.empty:
                # No results for this (setting, N) cell: hide the panel.
                ax.set_visible(False)
                continue

            dfN["model_abbr"] = dfN["model"].map(short_name)
            models_in_panel = sorted(dfN["model"].unique(), key=prior_rank)

            sns.lineplot(
                data=dfN,
                x="sparsity",
                y="posterior_mean_rmse",
                hue="model_abbr",      # colour = prior (abbreviated)
                style="activation",    # line style / marker = activation
                markers=True,
                dashes=True,
                palette={short_name(m): palette[m] for m in models_in_panel if m in palette},
                hue_order=[short_name(m) for m in models_in_panel],
                errorbar=None,
                ax=ax,
            )
            ax.set_title(f"N={Nval}", fontweight="normal", fontsize=15)
            ax.set_xlabel("Sparsity", fontsize=15)
            ax.set_ylabel("RMSE" if j == 0 else "", fontsize=15)
            ax.tick_params(axis="both", labelsize=10)
            ax.grid(True, which="major", alpha=0.25)
            # Remove the per-axes legend; one figure-level legend is added below.
            local_legend = ax.get_legend()
            if local_legend is not None:
                local_legend.remove()

    # Figure-level prior legend (colours only), restricted to priors present in df_all.
    models_present = [m for m in prior_order if (df_all["model"] == m).any()]
    prior_handles = [
        Line2D([0], [0], color=palette[m], marker="o", linestyle="-", linewidth=2, markersize=12)
        for m in models_present
    ]
    prior_labels = [short_name(m) for m in models_present]

    if prior_handles:
        # fig.legend already registers the legend with the figure. Passing it to
        # fig.add_artist as well would draw it a second time.
        fig.legend(
            prior_handles, prior_labels,
            title="Prior",
            loc="upper right",
            ncol=len(prior_handles),
            frameon=True,
            bbox_to_anchor=(0.7, 1.02),
            fontsize=15,
        )
    # No suptitle is drawn (the `title` argument is unused).
    plt.tight_layout(rect=[0, 0.4, 0.95, 0.95])
    # plt.savefig("figures_for_use_in_paper/friedman_sparsity_tanh_with_beta.pdf", bbox_inches="tight")
    plt.show()

In [None]:
plot_rmse_one_figure(df_all, Ns=(100, 200, 500), title="Original vs Correlated")

python3 utils/run_all_regression_models.py --model dirichlet_horseshoe_tanh_nodewise --output_dir results/regression/single_layer/tanh/friedman_correlated && python3 utils/run_all_regression_models.py --model dirichlet_student_t_tanh_nodewise --output_dir results/regression/single_layer/tanh/friedman_correlated && python3 utils/run_all_regression_models.py --model beta_horseshoe_tanh_nodewise --output_dir results/regression/single_layer/tanh/friedman_correlated && python3 utils/run_all_regression_models.py --model beta_student_t_tanh_nodewise --output_dir results/regression/single_layer/tanh/friedman_correlated

## Visualize pruned tanh networks (original vs. correlated)

In [9]:
from utils.visualize_networks import compute_activation_frequency, extract_all_pruned_means, plot_all_networks_subplots_activations
# X_train of the dataset whose fits are visualized in this section
# ('Friedman_N100_p10_sigma1.00_seed1'), so node activation frequencies are
# computed on that dataset's inputs. The same path is used by the ReLU cell further down.
# NOTE(review): the previous path pointed at 'datasets/friedman/many/..._seed6.npz',
# which does not match the seed-1 fits used below. Confirm this is the intended data.
path = "datasets/friedman/Friedman_N100_p10_sigma1.00_seed1.npz"
data = np.load(path)
x_train = data["X_train"]

In [10]:
# Per-model node activation frequencies for the uncorrelated N=100, seed-1 tanh
# fits, evaluated on `x_train` from the previous cell.
node_activation_colors = {
    model_name: compute_activation_frequency(fit, x_train)
    for model_name, fit in tanh_fits['Friedman_N100_p10_sigma1.00_seed1'].items()
}

# Largest frequency across all models: shared upper bound of the colour scale.
all_freqs = np.concatenate(list(node_activation_colors.values()))
global_max = all_freqs.max()

# The correlated fits are keyed by datasets in `data_dir_correlated`. Their
# activation frequencies are therefore computed on that dataset's X_train,
# not on the uncorrelated `x_train` above.
# Assumes the file stem equals the fit key, as for the uncorrelated data — TODO confirm.
with np.load(os.path.join(data_dir_correlated, "Friedman_N100_p10_sigma1.00_seed1.npz")) as npz_correlated:
    x_train_correlated = npz_correlated["X_train"]

node_activation_colors_correlated = {
    model_name: compute_activation_frequency(fit, x_train_correlated)
    for model_name, fit in tanh_fits_correlated['Friedman_N100_p10_sigma1.00_seed1'].items()
}

# Shared colour-scale maximum for the correlated networks.
all_freqs_correlated = np.concatenate(list(node_activation_colors_correlated.values()))
global_max_correlated = all_freqs_correlated.max()

In [11]:
# Layer sizes used to draw the networks:
# P inputs -> L hidden layers of H units -> out_nodes outputs.
P, H, L, out_nodes = 10, 16, 1, 1
layer_sizes = [P, *([H] * L), out_nodes]

# Name and shape of each weight matrix read when extracting pruned posterior means.
layer_structure = dict(
    input_to_hidden=dict(name='W_1', shape=(P, H)),
    hidden_to_output=dict(name='W_L', shape=(H, out_nodes)),
)

# Sparsity level passed to extract_all_pruned_means below.
sparsity_level = 0.9

In [None]:
# Prune every uncorrelated seed-1 tanh model at `sparsity_level` and draw the networks,
# colouring nodes by activation frequency on the shared scale `global_max`.
pruned_model_means = extract_all_pruned_means(
    tanh_fits['Friedman_N100_p10_sigma1.00_seed1'],
    layer_structure,
    sparsity_level,
)

p1, widths_1 = plot_all_networks_subplots_activations(
    pruned_model_means,
    layer_sizes,
    node_activation_colors,
    activation_color_max=global_max,
    signed_colors=False,
)


In [None]:
# Prune every correlated seed-1 tanh model at `sparsity_level`.
pruned_model_means_correlated = extract_all_pruned_means(tanh_fits_correlated['Friedman_N100_p10_sigma1.00_seed1'], layer_structure, sparsity_level)

# Draw the correlated networks with their own pruned means, activation
# frequencies and colour scale, all computed from tanh_fits_correlated.
p1_correlated, widths_1_correlated = plot_all_networks_subplots_activations(
    pruned_model_means_correlated,
    layer_sizes,
    node_activation_colors_correlated,
    activation_color_max=global_max_correlated,
    signed_colors=False,
)


## Node pruning

In [None]:
from utils.sparsity import compute_sparse_rmse_results, prune_nodes_by_output_weights

# Node-pruning results per activation, keyed by sparsity level (filled below).
df_rmse_node_relu = {}
df_posterior_rmse_node_relu = {}
df_rmse_node_tanh = {}
df_posterior_rmse_node_tanh = {}

def nodes_to_sparsity(nodes_to_prune_list, total_nodes):
    """
    Express node-pruning counts as fractions of the layer width.

    Args:
        nodes_to_prune_list: iterable of ints, the number of nodes to prune in each setting.
        total_nodes: number of nodes in the layer (must be non-zero).

    Returns:
        List containing ``n / total_nodes`` for each count ``n``, rounded to 4 decimals.
    """
    return [round(count / total_nodes, 4) for count in nodes_to_prune_list]

# The hidden layer has 16 nodes; evaluate pruning 0..14 of them.
total_nodes = 16
nodes_to_prune = [0, 1, 2, 4, 6, 8, 10, 12, 14]

node_sparsity = nodes_to_sparsity(nodes_to_prune, total_nodes)
print(node_sparsity)
# Output: [0.0, 0.0625, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875]

# For each level, evaluate ReLU and tanh fits with prune_nodes_by_output_weights
# as the pruning function.
for sparsity in node_sparsity:
    df_rmse_node_relu[sparsity], df_posterior_rmse_node_relu[sparsity] = compute_sparse_rmse_results(
        seeds, model_names_relu, relu_fits, get_N_sigma, forward_pass_relu,
        sparsity=sparsity, prune_fn=prune_nodes_by_output_weights,
    )

    df_rmse_node_tanh[sparsity], df_posterior_rmse_node_tanh[sparsity] = compute_sparse_rmse_results(
        seeds, model_names_tanh, tanh_fits, get_N_sigma, forward_pass_tanh,
        sparsity=sparsity, prune_fn=prune_nodes_by_output_weights,
    )

In [15]:
import pandas as pd

# Flatten each {sparsity: DataFrame} dict into one frame with a `sparsity` column
# (ReLU first, then tanh).
df_rmse_full_node_relu, df_rmse_full_node_tanh = (
    pd.concat(
        [part.assign(sparsity=level) for level, part in by_level.items()],
        ignore_index=True,
    )
    for by_level in (df_rmse_node_relu, df_rmse_node_tanh)
)




In [16]:
total_nodes = 16  # hidden-layer width; must match the node-pruning cell above

# Convert each sparsity level back to a node count. Round before casting:
# sparsity was rounded to 4 decimals, so sparsity * total_nodes can fall just
# below an integer (e.g. 0.2857 * 7 = 1.9999), and a bare astype(int) would
# truncate it to the wrong count.
df_rmse_full_node_relu['nodes_pruned'] = (df_rmse_full_node_relu['sparsity'] * total_nodes).round().astype(int)
df_rmse_full_node_tanh['nodes_pruned'] = (df_rmse_full_node_tanh['sparsity'] * total_nodes).round().astype(int)


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from collections import OrderedDict

# Drop the " tanh" suffix so the ReLU and tanh panels share model names and colours.
df_rmse_full_node_relu = df_rmse_full_node_relu.copy()
df_rmse_full_node_relu["model"] = df_rmse_full_node_relu["model"].str.replace(" tanh", "", regex=False)

df_rmse_full_node_tanh = df_rmse_full_node_tanh.copy()
df_rmse_full_node_tanh["model"] = df_rmse_full_node_tanh["model"].str.replace(" tanh", "", regex=False)

# One colour per prior. This must cover every model in model_names_relu /
# model_names_tanh, including the Beta priors: seaborn raises a ValueError when
# a dict palette is missing any hue level. Extra keys are harmless.
custom_palette = {
    "Gaussian": "C0",
    "Regularized Horseshoe": "C1",
    "Dirichlet Horseshoe": "C2",
    "Dirichlet Student T": "C3",
    "Beta Horseshoe": "C4",
    "Beta Student T": "C5",
}

# One panel per activation, shared axes.
fig, axes = plt.subplots(1, 2, figsize=(16, 5), sharex=True, sharey=True)

activation_data = [("ReLU", df_rmse_full_node_relu), ("tanh", df_rmse_full_node_tanh)]
all_handles_labels = []

for idx, (name, df) in enumerate(activation_data):
    ax = axes[idx]

    sns.lineplot(
        data=df,
        x='nodes_pruned', y='rmse',
        hue='model', style='N', marker='o', errorbar=None, ax=ax,
        palette=custom_palette
    )

    # Collect handles for one figure-level legend, then drop the per-axes legend.
    handles, labels = ax.get_legend_handles_labels()
    all_handles_labels.extend(zip(handles, labels))
    local_legend = ax.get_legend()
    if local_legend is not None:
        local_legend.remove()

    ax.set_title(f"{name} activation")
    ax.set_ylabel("RMSE")
    ax.set_xlabel("Nodes pruned")
    ax.grid(True)

# First handle per label, skipping seaborn's "model" / "N" section titles.
legend_dict = OrderedDict()
for handle, label in all_handles_labels:
    if label not in {"model", "N"} and label not in legend_dict:
        legend_dict[label] = handle

# The legend lists the priors only, in palette order.
desired_order = list(custom_palette)
filtered = [(legend_dict[label], label) for label in desired_order if label in legend_dict]

if filtered:
    filtered_handles, filtered_labels = zip(*filtered)
    fig.legend(
        filtered_handles,
        filtered_labels,
        title="Model",
        loc="upper center",
        bbox_to_anchor=(0.5, 1.05),
        ncol=2,
    )

plt.tight_layout(rect=[0, 0.05, 1, 1])
plt.show()

## Visualize ReLU node activation frequencies

In [21]:
from utils.visualize_networks import (
    compute_activation_frequency,
    extract_all_pruned_means,
    plot_all_networks_subplots_activations,
)

# Training inputs of the uncorrelated N=100, seed-1 Friedman dataset,
# matching the relu_fits['Friedman_N100_p10_sigma1.00_seed1'] entry used below.
path = "datasets/friedman/Friedman_N100_p10_sigma1.00_seed1.npz"
data = np.load(path)
x_train = data["X_train"]

In [None]:
# Activation frequency of each hidden node for every seed-1 ReLU model, evaluated on `x_train`.
node_activation_colors = {
    name: compute_activation_frequency(model_fit, x_train)
    for name, model_fit in relu_fits['Friedman_N100_p10_sigma1.00_seed1'].items()
}

# Largest frequency across all models: used as the shared colour-scale maximum.
all_freqs = np.concatenate([freqs for freqs in node_activation_colors.values()])
global_max = all_freqs.max()
print(global_max)

## Visualize node-pruned ReLU networks

In [None]:
from utils.visualize_networks import (
    extract_all_pruned_node_means,
    plot_all_networks_subplots_activations,
)

# Node count passed to extract_all_pruned_node_means for each seed-1 ReLU model.
num_nodes_to_prune = 6  # for example
pruned_model_means_nodes = extract_all_pruned_node_means(
    relu_fits['Friedman_N100_p10_sigma1.00_seed1'],
    layer_structure,
    num_nodes_to_prune,
)

p_nodes, widths_nodes = plot_all_networks_subplots_activations(
    pruned_model_means_nodes,
    layer_sizes,
    node_activation_colors,
    activation_color_max=global_max,
    signed_colors=False,
)
