In [None]:
import sys, os; sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__) if '__file__' in globals() else os.getcwd(), '..')))
#import os; os.chdir(os.path.dirname(os.getcwd()))
from utils.model_loader import get_model_fits
import numpy as np
import pandas as pd
import re
from sklearn.metrics import mean_squared_error
import seaborn as sns
import matplotlib.pyplot as plt


In [None]:
data_dir = f"datasets/friedman_correlated"
results_dir_tanh_corr = "results/regression/single_layer/tanh/friedman_correlated"

model_names_tanh_corr = ["Dirichlet Horseshoe tanh nodewise", "Dirichlet Horseshoe tanh", "Dirichlet Student T tanh nodewise", "Dirichlet Student T tanh", "Beta Horseshoe tanh nodewise", "Beta Horseshoe tanh"]

tanh_fits_corr = {}

files = sorted(f for f in os.listdir(data_dir) if f.endswith(".npz"))
for fname in files:
    base_config_name = fname.replace(".npz", "")
    full_config_path = f"{base_config_name}"
    
    tanh_fit_corr = get_model_fits(
        config=full_config_path,
        results_dir=results_dir_tanh_corr,
        models=model_names_tanh_corr,
        include_prior=False,
    )
    
    tanh_fits_corr[base_config_name] = tanh_fit_corr
    


In [3]:
import re
import numpy as np
import pandas as pd
from utils.generate_data import generate_Friedman_data, generate_correlated_Friedman_data

_FRIEDMAN_KEY = re.compile(r"Friedman_N(\d+)_p\d+_sigma([\d.]+)_seed(\d+)")

def extract_friedman_metadata(key: str):
    """
    Parse 'Friedman_N{N}_p10_sigma{sigma}_seed{seed}' -> (N:int, sigma:float, seed:int)
    Returns (None, None, None) if it doesn't match.
    """
    m = _FRIEDMAN_KEY.search(key)
    if not m:
        return None, None, None
    N = int(m.group(1))
    sigma = float(m.group(2))
    seed = int(m.group(3))
    return N, sigma, seed

def forward_pass_tanh(X, W1, b1, WL, bL):
    # X: (N,P), W1: (P,H), b1: (H,), WL: (H,O), bL: (O,)
    H = np.tanh(X @ W1 + b1.reshape(1, -1))
    Y = H @ WL + bL.reshape(1, -1)
    return Y  # (N,O)


In [4]:
import re
import numpy as np
import pandas as pd
from sklearn.metrics import mean_squared_error

def _find_run_key(fits_by_run, *, N, D, sigma, seed, correlated=False):
    """
    Tries to find the right key inside e.g. tanh_fits for the requested (N, D, sigma, seed).
    Works even if your sigma formatting differs a bit (e.g. 1, 1.0, 1.00).
    """
    # Common naming you showed: Friedman_N100_p10_sigma1.00_seed3
    # If you also have "CorrelatedFriedman" etc, we try both patterns.
    sigma_strs = [
        f"{sigma}",
        f"{sigma:.1f}",
        f"{sigma:.2f}",
        f"{sigma:.3f}",
    ]
    candidates = []
    for s in sigma_strs:
        if correlated:
            candidates.append(f"CorrelatedFriedman_N{N}_p{D}_sigma{s}_seed{seed}")
        candidates.append(f"Friedman_N{N}_p{D}_sigma{s}_seed{seed}")

    for k in candidates:
        if k in fits_by_run:
            return k

    # If exact match not found, do a regex search (robust to extra prefixes/suffixes)
    # Example matches: "...Friedman_N100_p10_sigma1.00_seed3..."
    sig_pat = "|".join(re.escape(s) for s in sigma_strs)
    base = r"Friedman" if not correlated else r"(?:CorrelatedFriedman|Friedman)"
    pat = re.compile(
        rf"{base}_N{N}_p{D}_sigma(?:{sig_pat})_seed{seed}"
    )
    matches = [k for k in fits_by_run.keys() if pat.search(k)]
    if len(matches) == 1:
        return matches[0]
    if len(matches) > 1:
        # Prefer the shortest/most exact-looking key
        matches = sorted(matches, key=len)
        return matches[0]

    raise KeyError(
        f"Could not find a run in fits_by_run for N={N}, D={D}, sigma={sigma}, seed={seed}, correlated={correlated}.\n"
        f"Example expected key: 'Friedman_N{N}_p{D}_sigma{sigma:.2f}_seed{seed}'."
    )


def evaluate_posterior_on_multiple_testsets(
    fits_by_run,            # <-- pass tanh_fits here
    models,
    layers,
    forward_pass,
    *,
    correlated=False,
    sigma=1.0,
    D=10,
    N_train=100,
    Ns=(5000,),             # <-- run multiple N's if you want
    n_testsets=1,
    seeds=(1,),
    test_seed_base=123,
):
    """
    - You pass tanh_fits (a dict keyed by run-name like 'Friedman_N100_p10_sigma1.00_seed3')
    - For each (N, seed), we:
        1) generate n_testsets independent test sets (seed = test_seed_base + test_id)
        2) evaluate posterior mean prediction RMSE on each test set
        3) return:
           df_mean: mean RMSE over testsets, *separately* per seed (not merged)
           df_raw : per-testset RMSE rows
    """
    rows = []

    for N in Ns:
        for seed in seeds:
            run_key = _find_run_key(
                fits_by_run, N=N_train, D=D, sigma=sigma, seed=seed, correlated=correlated
            )
            fits = fits_by_run[run_key]  # <-- this is what you previously indexed manually

            for test_id in range(n_testsets):
                # Make test sets genuinely different:
                data_seed = 123

                if correlated:
                    _, X_test, y_train_raw, y_test_raw = generate_correlated_Friedman_data(
                        N=N, D=D, sigma=sigma, test_size=0.2, seed=data_seed, standardize_y=False
                    )
                else:
                    _, X_test, y_train_raw, y_test_raw = generate_Friedman_data(
                        N=N, D=D, sigma=sigma, test_size=0.2, seed=data_seed, standardize_y=False
                    )

                y_train_mean = y_train_raw.mean()
                y_train_std = y_train_raw.std()

                y_test = (y_test_raw - y_train_mean) / y_train_std
                y_test_np = y_test.reshape(-1)

                for model in models:
                    fit = fits[model]["posterior"]

                    W1_samples = fit.stan_variable("W_1")         # (S, P, H)
                    if layers == 2:
                        W2_samples = fit.stan_variable("W_2")     # (S, H, H) or similar
                    WL_samples = fit.stan_variable("W_L")         # (S, H, O)
                    b_samples = fit.stan_variable("hidden_bias")  # (S, L, H)
                    b1_samples = b_samples[:, 0, :]
                    if layers == 2:
                        b2_samples = b_samples[:, 1, :]
                    bL_samples = fit.stan_variable("output_bias") # (S, O)

                    S = W1_samples.shape[0]
                    y_hats = np.zeros((S, y_test_np.shape[0]))

                    for i in range(S):
                        W1 = W1_samples[i]
                        WL = WL_samples[i]

                        if layers == 1:
                            y_hat = forward_pass(X_test, W1, b1_samples[i], WL, bL_samples[i])
                        else:
                            W2 = W2_samples[i]
                            y_hat = forward_pass(X_test, W1, b1_samples[i], W2, b2_samples[i], WL, bL_samples[i])

                        y_hats[i] = y_hat.squeeze()

                    post_mean = y_hats.mean(axis=0)
                    posterior_rmse_std = np.sqrt(mean_squared_error(y_test_np, post_mean))
                    posterior_rmse_rawscale = posterior_rmse_std * y_train_std

                    rows.append({
                        "run_key": run_key,
                        "model": model,
                        "N": N,
                        "D": D,
                        "sigma": sigma,
                        "correlated": correlated,
                        "seed": seed,
                        "test_set": test_id,
                        "posterior_rmse": posterior_rmse_rawscale,
                    })

    df = pd.DataFrame(rows)

    # Mean over testsets, BUT kept separate for each (model, seed, N, ...)
    group_cols = ["model", "seed", "N", "D", "sigma", "correlated"]
    df_mean = (
        df.groupby(group_cols, as_index=False)["posterior_rmse"]
          .mean()
          .rename(columns={"posterior_rmse": "mean_rmse_over_testsets (scaled)"})
          .sort_values(group_cols)
          .reset_index(drop=True)
    )

    return df_mean, df


In [17]:
df_mean_N100, df_raw_N100 = evaluate_posterior_on_multiple_testsets(
    fits_by_run=tanh_fits,
    models=list(tanh_fits['Friedman_N100_p10_sigma1.00_seed1'].keys()),
    layers=1,
    forward_pass=forward_pass_tanh,
    correlated=False,
    sigma=1.00,
    D=10,
    N_train = 100,
    Ns=[5000],
    n_testsets=1,
    seeds=[1],
)

In [14]:
df_mean_N200, df_raw_N200 = evaluate_posterior_on_multiple_testsets(
    fits_by_run=tanh_fits,
    models=list(tanh_fits['Friedman_N200_p10_sigma1.00_seed2'].keys()),
    layers=1,
    forward_pass=forward_pass_tanh,
    correlated=False,
    sigma=1.00,
    D=10,
    N_train = 200,
    Ns=[5000],
    n_testsets=5,
    seeds=[2],
)

In [15]:
df_mean_N500, df_raw_N500 = evaluate_posterior_on_multiple_testsets(
    fits_by_run=tanh_fits,
    models=list(tanh_fits['Friedman_N500_p10_sigma1.00_seed11'].keys()),
    layers=1,
    forward_pass=forward_pass_tanh,
    correlated=False,
    sigma=1.00,
    D=10,
    N_train = 500,
    Ns=[5000],
    n_testsets=5,
    seeds=[11],
)

In [11]:
df_mean_N100_corr, df_raw_N100_corr = evaluate_posterior_on_multiple_testsets(
    fits_by_run=tanh_fits_corr,
    models=list(tanh_fits_corr['Friedman_N100_p10_sigma1.00_seed1'].keys()),
    layers=1,
    forward_pass=forward_pass_tanh,
    correlated=True,
    sigma=1.00,
    D=10,
    N_train = 100,
    Ns=[5000],
    n_testsets=1,
    seeds=[1],
)

In [None]:
df_mean_N100_corr.sort_values(by="mean_rmse_over_testsets (scaled)")

In [13]:
df_mean_N200_corr, df_raw_N200_corr = evaluate_posterior_on_multiple_testsets(
    fits_by_run=tanh_fits_corr,
    models=list(tanh_fits_corr['Friedman_N200_p10_sigma1.00_seed6'].keys()),
    layers=1,
    forward_pass=forward_pass_tanh,
    correlated=True,
    sigma=1.00,
    D=10,
    N_train = 200,
    Ns=[5000],
    n_testsets=1,
    seeds=[6],
)

In [None]:
df_mean_N200_corr.sort_values(by="mean_rmse_over_testsets (scaled)")

In [15]:
df_mean_N500_corr, df_raw_N500_corr = evaluate_posterior_on_multiple_testsets(
    fits_by_run=tanh_fits_corr,
    models=list(tanh_fits_corr['Friedman_N500_p10_sigma1.00_seed11'].keys()),
    layers=1,
    forward_pass=forward_pass_tanh,
    correlated=True,
    sigma=1.00,
    D=10,
    N_train = 500,
    Ns=[5000],
    n_testsets=1,
    seeds=[11],
)

In [None]:
df_mean_N500_corr.sort_values(by="mean_rmse_over_testsets (scaled)")

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# posterior mean prediction on standardized scale (most likely)
output_mean_std_DHS = (
    tanh_fits_corr['Friedman_N500_p10_sigma1.00_seed11']
    ['Dirichlet Horseshoe tanh nodewise']['posterior']
    .stan_variable("output_test")
    .mean(axis=0)
).reshape(-1)

output_mean_std_BHS = (
    tanh_fits_corr['Friedman_N500_p10_sigma1.00_seed11']
    ['Beta Horseshoe tanh nodewise']['posterior']
    .stan_variable("output_test")
    .mean(axis=0)
).reshape(-1)

_, X_test, y_train_raw, y_test_raw = generate_correlated_Friedman_data(
    N=500, D=10, sigma=1.0, test_size=0.2, seed=11, standardize_y=False
)

y_train_raw = y_train_raw.reshape(-1)
y_test_raw  = y_test_raw.reshape(-1)

y_mean = y_train_raw.mean()
y_std  = y_train_raw.std()

# IMPORTANT: fully invert standardization if output is on standardized scale
y_pred_raw_DHS = y_mean + y_std * output_mean_std_DHS
y_pred_raw_BHS = y_mean + y_std * output_mean_std_BHS

# ---- sanity checks (very helpful) ----
print("len(output_mean_std):", len(output_mean_std_DHS))
print("len(y_train_raw):", len(y_train_raw), "len(y_test_raw):", len(y_test_raw))

# If your "output" is for TRAINING points, it should match len(y_train_raw) (80 here).
# If it matches 20, it’s probably test output. If it matches 100, it’s for all data.
# Pick the corresponding y_true.
if len(output_mean_std_DHS) == len(y_train_raw):
    y_true_raw = y_train_raw
    title_suffix = "train"
elif len(output_mean_std_DHS) == len(y_test_raw):
    y_true_raw = y_test_raw
    title_suffix = "test"
else:
    # fallback: truncate to min length (not ideal, but avoids crashing)
    # n = min(len(output_mean_std_DHS), len(y_train_raw))
    # y_true_raw = y_train_raw[:n]
    # y_pred_raw = y_pred_raw[:n]
    print(f"ERROR")

# ---- plot 1: pred vs true (scatter with y=x line) ----
plt.figure()
plt.scatter(y_true_raw, y_pred_raw_DHS, alpha=0.7, label="DHS", color="Orange")
lo = min(y_true_raw.min(), y_pred_raw_DHS.min())
hi = max(y_true_raw.max(), y_pred_raw_DHS.max())
plt.scatter(y_true_raw, y_pred_raw_BHS, alpha=0.7, label="BHS", color="Green")
plt.plot([lo, hi], [lo, hi])  # identity line
plt.xlabel("True y (raw)")
plt.ylabel("Predicted y (raw)")
plt.title(f"Pred vs True ({title_suffix})")
plt.legend()
plt.grid(True, alpha=0.3)

# ---- plot 2: index plot (both series over observations) ----
plt.figure()
idx = np.arange(len(y_true_raw))
plt.scatter(idx, y_true_raw, label="true", alpha=0.8)
plt.scatter(idx, y_pred_raw_DHS, label="pred DHS", alpha=0.8)
plt.scatter(idx, y_pred_raw_BHS, label="pred BHS", alpha=0.8)
plt.xlabel("Index")
plt.ylabel("y (raw)")
plt.title(f"True and Pred over index ({title_suffix})")
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()


## TEST DIFFERENT PRUNING SCHEME

In [5]:
import numpy as np

def build_global_mask_from_posterior(
    W_samples,
    sparsity,
    method="Eabs",          # "Eabs" or "Eabs_stability"
    stability_quantile=0.1, # used if method="Eabs_stability"
    prune_smallest=True
):
    """
    W_samples: array (S, ..., ...) posterior draws of a weight matrix.
    sparsity: fraction to prune (q). Keeps (1-q).
    Returns mask with same trailing shape as one draw, dtype float {0,1}.
    """
    assert 0.0 <= sparsity < 1.0
    S = W_samples.shape[0]
    W_abs = np.abs(W_samples)  # (S, ...)

    # Importance score a = E|w|
    a = W_abs.mean(axis=0)     # (..., ...)

    if method == "Eabs":
        score = a
    elif method == "Eabs_stability":
        # Stability proxy pi = P(|w| > t), where t is a small global quantile of |w|
        t = np.quantile(W_abs.reshape(S, -1), stability_quantile)
        pi = (W_abs > t).mean(axis=0)
        # Combine: emphasize both "large on average" and "consistently non-tiny"
        score = a * pi
    else:
        raise ValueError("method must be 'Eabs' or 'Eabs_stability'")

    # Decide how many to prune
    num_params = score.size
    k_prune = int(np.floor(sparsity * num_params))
    if k_prune == 0:
        return np.ones_like(score, dtype=float)

    flat = score.reshape(-1)

    if prune_smallest:
        # prune lowest scores
        thresh = np.partition(flat, k_prune - 1)[k_prune - 1]
        mask = (score > thresh).astype(float)
        # if ties create too many kept/pruned, fix deterministically
        # (rare but possible with many equal scores)
        if mask.sum() > num_params - k_prune:
            # drop some tied-at-threshold entries
            idx_tied = np.where(score.reshape(-1) == thresh)[0]
            need_drop = int(mask.sum() - (num_params - k_prune))
            if need_drop > 0:
                mask_flat = mask.reshape(-1)
                mask_flat[idx_tied[:need_drop]] = 0.0
                mask = mask_flat.reshape(score.shape)
        elif mask.sum() < num_params - k_prune:
            # add some tied entries if we kept too few
            idx_tied = np.where(score.reshape(-1) == thresh)[0]
            need_add = int((num_params - k_prune) - mask.sum())
            if need_add > 0:
                mask_flat = mask.reshape(-1)
                # add back from tied
                add_candidates = idx_tied[mask_flat[idx_tied] == 0.0]
                mask_flat[add_candidates[:need_add]] = 1.0
                mask = mask_flat.reshape(score.shape)
    else:
        # prune largest (not typical)
        thresh = np.partition(flat, num_params - k_prune)[num_params - k_prune]
        mask = (score < thresh).astype(float)

    return mask


def precompute_global_masks(
    all_fits,
    dataset_key,
    model,
    sparsity_levels,
    prune_W2=False,
    method="Eabs_stability"
):
    """
    Returns dict: sparsity -> (mask_W1, mask_W2 or None)
    """
    fit = all_fits[dataset_key][model]["posterior"]

    W1_samples = fit.stan_variable("W_1")  # (S, P, H)
    W2_samples = fit.stan_variable("W_L")  # (S, H, O) or (S, H) depending on O

    masks = {}
    for q in sparsity_levels:
        mask_W1 = build_global_mask_from_posterior(W1_samples, q, method=method)
        mask_W2 = None
        if prune_W2:
            mask_W2 = build_global_mask_from_posterior(W2_samples, q, method=method)
        masks[q] = (mask_W1, mask_W2)
    return masks


In [6]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from utils.generate_data import sample_gaussian_copula_uniform


def generate_Friedman_data_v2(N=200, D=10, sigma=1.0, test_size=0.2, seed=42, standardize_y=True, return_scale=True):
    np.random.seed(seed)
    X = np.random.uniform(0, 1, size=(N, D))
    x0, x1, x2, x3, x4 = X[:, 0], X[:, 1], X[:, 2], X[:, 3], X[:, 4]

    y_clean = (
        10 * np.sin(np.pi * x0 * x1) +
        20 * (x2 - 0.5) ** 2 +
        10 * x3 +
        5.0 * x4
    )
    y = y_clean + np.random.normal(0, sigma, size=N)

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=seed)

    if not standardize_y:
        return (X_train, X_test, y_train, y_test) if not return_scale else (X_train, X_test, y_train, y_test, 0.0, 1.0)

    y_mean = y_train.mean()
    y_std = y_train.std() if y_train.std() > 0 else 1.0

    y_train_s = (y_train - y_mean) / y_std
    y_test_s = (y_test - y_mean) / y_std

    if return_scale:
        return X_train, X_test, y_train_s, y_test_s, y_mean, y_std
    return X_train, X_test, y_train_s, y_test_s

def generate_correlated_Friedman_data_v2(N=100, D=10, sigma=1.0, test_size=0.2, seed=42, standardize_y=True, return_scale=True):
    """
    Generate synthetic regression data for Bayesian neural network experiments.

    Parameters:
        N (int): Number of samples.
        D (int): Number of features.
        sigma (float): Noise level.
        test_size (float): Proportion for test split.
        seed (int): Random seed.
        standardize_y (bool): Whether to standardize the response variable.

    Returns:
        tuple: (X_train, X_test, y_train, y_test, y_mean, y_std) if standardize_y,
               else (X_train, X_test, y_train, y_test)
    """
    np.random.seed(seed)
    d = 10
    S_custom = np.eye(d)
    # Block 1 (vars 0..4): high Spearman, 0.7
    for i in range(0, 3):
        for j in range(i+1, 3):
            S_custom[i, j] = S_custom[j, i] = 0.8
    # Block 2 (vars 5..9): moderate Spearman, 0.4
    for i in range(5, 10):
        for j in range(i+1, 10):
            S_custom[i, j] = S_custom[j, i] = -0.5
    # Cross-block weaker, 0.15
    for i in range(0, 5):
        for j in range(5, 10):
            S_custom[i, j] = S_custom[j, i] = 0.15
    # A couple of bespoke pairs:
    S_custom[0, 9] = S_custom[9, 0] = 0.4
    S_custom[2, 7] = S_custom[7, 2] = 0.9  # very strong (will be projected if infeasible)
    S_custom[3, 4] = S_custom[4, 3] = -0.9  # very strong (will be projected if infeasible)
    S_custom[1, 6] = S_custom[6, 1] = -0.9  # very strong (will be projected if infeasible)

    U, _ = sample_gaussian_copula_uniform(n=10000, S=S_custom, random_state=123)
    #X = np.random.uniform(0, 1, size=(N, D))
    if N != U.shape[0]:
        idx = np.random.choice(U.shape[0], size=N, replace=False)
        X = U[idx, :]
    else:
        X = U

    x0, x1, x2, x3, x4 = X[:, 0], X[:, 1], X[:, 2], X[:, 3], X[:, 4]

    y_clean = (
        10 * np.sin(np.pi * x0 * x1) +
        20 * (x2 - 0.5) ** 2 +
        10 * x3 +
        5.0 * x4
    )

    y = y_clean + np.random.normal(0, sigma, size=N)

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=seed)

    if not standardize_y:
        return (X_train, X_test, y_train, y_test) if not return_scale else (X_train, X_test, y_train, y_test, 0.0, 1.0)

    y_mean = y_train.mean()
    y_std = y_train.std() if y_train.std() > 0 else 1.0

    y_train_s = (y_train - y_mean) / y_std
    y_test_s = (y_test - y_mean) / y_std

    if return_scale:
        return X_train, X_test, y_train_s, y_test_s, y_mean, y_std
    return X_train, X_test, y_train_s, y_test_s

def make_large_eval_set(
    generator_fn,
    N_train,
    D,
    sigma,
    seed,
    n_eval=5000,
    standardize_y=True
):
    """
    Returns X_eval, y_eval (standardized if standardize_y=True), plus y_mean,y_std
    defined from the training split.
    """
    N_total = N_train + n_eval

    X_tr, X_te, y_tr, y_te, y_mean, y_std = generator_fn(
        N=N_total, D=D, sigma=sigma, test_size=n_eval / N_total, seed=seed,
        standardize_y=standardize_y, return_scale=True
    )
    # Now X_te has approx n_eval points (exact given test_size construction).
    return X_te, np.asarray(y_te).squeeze(), y_mean, y_std


def _logsumexp(a, axis=None):
    amax = np.max(a, axis=axis, keepdims=True)
    out = amax + np.log(np.sum(np.exp(a - amax), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)

def gaussian_nll_pointwise(y, mu, sigma):
    return 0.5*np.log(2*np.pi*(sigma**2)) + 0.5*((y-mu)**2)/(sigma**2)

def compute_sparse_metrics_results_globalmask_large_eval(
    seeds, models, all_fits, get_N_sigma, forward_pass,
    folder,
    sparsity=0.0,
    masks_cache=None,
    prune_W2=False,
    compute_nll=True,
    noise_var_name="sigma",
    n_eval=5000,
    D=10,
    standardize_y=True,
    # pass the correct generator functions
    gen_uncorr=None,
    gen_corr=None,
):
    """
    Evaluate on a large generated test set instead of the stored tiny X_test/y_test.
    Assumes model was trained on standardized y if standardize_y=True.
    """
    assert gen_uncorr is not None and gen_corr is not None, "Pass both generator functions."

    results = []
    posterior_means = []

    # choose generator based on folder name
    def choose_gen(folder):
        return gen_corr if "friedman_correlated" in folder else gen_uncorr

    for seed in seeds:
        N, sigma = get_N_sigma(seed)
        dataset_key = f'Friedman_N{N}_p10_sigma{sigma:.2f}_seed{seed}'

        # Build large eval set consistent with training split standardization
        gen_fn = choose_gen(folder)
        X_test, y_test, y_mean, y_std = make_large_eval_set(
            generator_fn=gen_fn,
            N_train=N,
            D=D,
            sigma=sigma,
            seed=seed,
            n_eval=n_eval,
            standardize_y=standardize_y
        )

        for model in models:
            try:
                fit = all_fits[dataset_key][model]['posterior']
                W1_samples = fit.stan_variable("W_1")           # (S, P, H)
                W2_samples = fit.stan_variable("W_L")           # (S, H, O)
                b1_samples = fit.stan_variable("hidden_bias")   # (S, O, H)
                b2_samples = fit.stan_variable("output_bias")   # (S, O)

                noise_samples = None
                if compute_nll:
                    try:
                        noise_samples = fit.stan_variable(noise_var_name).squeeze()
                    except Exception:
                        noise_samples = None
            except KeyError:
                print(f"[SKIP] Model or posterior not found: {dataset_key} -> {model}")
                continue

            S = W1_samples.shape[0]
            y_hats = np.zeros((S, y_test.shape[0]))
            rmses = np.zeros(S)

            mask_W1 = mask_W2 = None
            if masks_cache is not None and sparsity > 0.0:
                mask_W1, mask_W2 = masks_cache[(dataset_key, model)][sparsity]

            for i in range(S):
                W1 = W1_samples[i]
                W2 = W2_samples[i]

                if mask_W1 is not None:
                    W1 = W1 * mask_W1
                if prune_W2 and (mask_W2 is not None):
                    W2 = W2 * mask_W2

                y_hat = forward_pass(X_test, W1, b1_samples[i][0], W2, b2_samples[i]).squeeze()
                y_hats[i] = y_hat
                rmses[i] = np.sqrt(np.mean((y_hat - y_test)**2))

            # posterior mean RMSE (standardized scale)
            posterior_mean = y_hats.mean(axis=0)
            posterior_mean_rmse = np.sqrt(np.mean((posterior_mean - y_test)**2))

            out_pm = {
                'seed': seed,
                'N': N,
                'sigma': sigma,
                'model': model,
                'sparsity': sparsity,
                'n_eval': y_test.shape[0],
                'posterior_mean_rmse': posterior_mean_rmse,
                'posterior_mean_rmse_orig': posterior_mean_rmse * y_std,  # back to original y scale
            }

            if compute_nll:
                if noise_samples is None:
                    sig_s = np.ones(S)
                else:
                    sig_s = np.asarray(noise_samples).reshape(-1)[:S]

                # Expected NLL
                nll_draws = np.array([
                    gaussian_nll_pointwise(y_test, y_hats[i], sig_s[i]).mean()
                    for i in range(S)
                ])
                expected_nll = nll_draws.mean()

                # Predictive (mixture) NLL
                loglik = -np.stack([
                    gaussian_nll_pointwise(y_test, y_hats[i], sig_s[i])
                    for i in range(S)
                ], axis=0)  # (S, n_eval)
                lppd = (_logsumexp(loglik, axis=0) - np.log(S)).mean()
                predictive_nll = -lppd

                out_pm["expected_nll"] = expected_nll
                out_pm["predictive_nll"] = predictive_nll

                # Optional: predictive_nll on original scale (only if you also rescale sigma)
                # If your sigma posterior is on standardized scale, original sigma = sig_s * y_std.
                out_pm["predictive_nll_orig"] = predictive_nll + np.log(y_std)  # see note below

            posterior_means.append(out_pm)

            for i in range(S):
                row = {
                    'seed': seed,
                    'N': N,
                    'sigma': sigma,
                    'model': model,
                    'sparsity': sparsity,
                    'n_eval': y_test.shape[0],
                    'rmse': rmses[i],
                    'rmse_orig': rmses[i] * y_std
                }
                if compute_nll:
                    row["nll"] = gaussian_nll_pointwise(y_test, y_hats[i], sig_s[i]).mean()
                results.append(row)

    return pd.DataFrame(results), pd.DataFrame(posterior_means)


In [7]:
from utils.sparsity import forward_pass_relu, forward_pass_tanh, local_prune_weights

def build_masks_cache_for_all(
    all_fits,
    dataset_keys,
    models,
    sparsity_levels,
    prune_W2=False,
    method="Eabs_stability"
):
    masks_cache = {}
    for dataset_key in dataset_keys:
        for model in models:
            try:
                masks_cache[(dataset_key, model)] = precompute_global_masks(
                    all_fits=all_fits,
                    dataset_key=dataset_key,
                    model=model,
                    sparsity_levels=sparsity_levels,
                    prune_W2=prune_W2,
                    method=method
                )
            except KeyError:
                print(f"[SKIP MASKS] Missing fit for {dataset_key} -> {model}")
    return masks_cache

seeds = [1]
seeds_correlated = [1]

sparsity_levels = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]


def get_N_sigma(seed):
    if seed == 16:
        N=50
    elif seed == 1:
        N=100
    elif seed == 2:
        N=200
    else:
        N=500
    sigma=1.00
    return N, sigma

def get_N_sigma_correlated(seed):
    if seed == 16:
        N=50
    elif seed == 1:
        N=100
    elif seed == 6:
        N=200
    else:
        N=500
    sigma=1.00
    return N, sigma

# Build the list of dataset keys you actually evaluate (same keys as in your compute loop)
dataset_keys = []
for seed in seeds:
    N, sigma = get_N_sigma(seed)
    dataset_keys.append(f'Friedman_N{N}_p10_sigma{sigma:.2f}_seed{seed}')

dataset_keys_corr = []
for seed in seeds_correlated:
    N, sigma = get_N_sigma_correlated(seed)
    dataset_keys_corr.append(f'Friedman_N{N}_p10_sigma{sigma:.2f}_seed{seed}')


# Precompute masks once
masks_tanh_corr = build_masks_cache_for_all(tanh_fits_corr, dataset_keys_corr, model_names_tanh_corr, sparsity_levels, prune_W2=False)

df_rmse_tanh_corr, df_post_tanh_corr = {}, {}
for q in sparsity_levels:
    df_rmse_tanh_corr[q], df_post_tanh_corr[q] = compute_sparse_metrics_results_globalmask_large_eval(
        seeds=seeds_correlated,
        models=model_names_tanh_corr,
        all_fits=tanh_fits_corr,
        get_N_sigma=get_N_sigma_correlated,
        forward_pass=forward_pass_tanh,
        folder="friedman_correlated",
        sparsity=q,
        masks_cache=masks_tanh_corr,
        prune_W2=False,
        compute_nll=True,
        noise_var_name="sigma",
        n_eval=5000,
        D=10,
        standardize_y=True,
        gen_uncorr=generate_Friedman_data_v2,
        gen_corr=generate_correlated_Friedman_data_v2,
    )


In [8]:
df_post_tanh_full = pd.concat(
    [df.assign(sparsity=sparsity) for sparsity, df in df_post_tanh_corr.items()],
    ignore_index=True
)

In [None]:
df_post_tanh_full[df_post_tanh_full['sparsity']==0.0].sort_values(by="predictive_nll")

In [None]:
df_post_tanh_full[df_post_tanh_full['sparsity']==0.0].sort_values(by="posterior_mean_rmse")

In [None]:
plt.figure()
plt.plot(sparsity_levels, df_post_tanh_full[df_post_tanh_full['model']=="Dirichlet Horseshoe tanh"]['predictive_nll'], label="DHS")
plt.plot(sparsity_levels, df_post_tanh_full[df_post_tanh_full['model']=="Dirichlet Horseshoe tanh nodewise"]['predictive_nll'], label="DHS - node")
plt.plot(sparsity_levels, df_post_tanh_full[df_post_tanh_full['model']=="Dirichlet Student T tanh"]['predictive_nll'], label="DST")
plt.plot(sparsity_levels, df_post_tanh_full[df_post_tanh_full['model']=="Dirichlet Student T tanh nodewise"]['predictive_nll'], label="DST - node")
plt.plot(sparsity_levels, df_post_tanh_full[df_post_tanh_full['model']=="Beta Horseshoe tanh"]['predictive_nll'], label="BHS")
plt.plot(sparsity_levels, df_post_tanh_full[df_post_tanh_full['model']=="Beta Horseshoe tanh nodewise"]['predictive_nll'], label="BHS - node")
plt.xlabel("Sparsity")
plt.ylabel("PLL")
plt.legend()
plt.grid()
plt.show()


In [None]:
plt.figure()
plt.plot(sparsity_levels, df_post_tanh_full[df_post_tanh_full['model']=="Dirichlet Horseshoe tanh"]['posterior_mean_rmse'], label="DHS")
plt.plot(sparsity_levels, df_post_tanh_full[df_post_tanh_full['model']=="Dirichlet Horseshoe tanh nodewise"]['posterior_mean_rmse'], label="DHS - node")
plt.plot(sparsity_levels, df_post_tanh_full[df_post_tanh_full['model']=="Dirichlet Student T tanh"]['posterior_mean_rmse'], label="DST")
plt.plot(sparsity_levels, df_post_tanh_full[df_post_tanh_full['model']=="Dirichlet Student T tanh nodewise"]['posterior_mean_rmse'], label="DST - node")
plt.plot(sparsity_levels, df_post_tanh_full[df_post_tanh_full['model']=="Beta Horseshoe tanh"]['posterior_mean_rmse'], label="BHS")
plt.plot(sparsity_levels, df_post_tanh_full[df_post_tanh_full['model']=="Beta Horseshoe tanh nodewise"]['posterior_mean_rmse'], label="BHS - node")
plt.xlabel("Sparsity")
plt.ylabel("RMSE")
plt.legend()
plt.grid()
plt.show()