In [None]:
import sys, os; sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__) if '__file__' in globals() else os.getcwd(), '..')))
#import os; os.chdir(os.path.dirname(os.getcwd()))
from utils.model_loader import get_model_fits
import numpy as np
import pandas as pd
import re
from sklearn.metrics import mean_squared_error
import seaborn as sns
import matplotlib.pyplot as plt


In [None]:
# Locations of the abalone data and of the stored single-hidden-layer regression fits.
data_dir = "datasets/abalone"
results_dir_relu = "results/regression/single_layer/relu/abalone"
results_dir_tanh = "results/regression/single_layer/tanh/abalone"

# Shrinkage priors compared in this notebook (the Gaussian and Regularized
# Horseshoe baselines are not evaluated here). The tanh runs are stored under
# the same names with a " tanh" suffix.
model_names_relu = ["Dirichlet Horseshoe", "Dirichlet Student T", "Beta Horseshoe", "Beta Student T"]
model_names_tanh = [f"{name} tanh" for name in model_names_relu]

# Identifier of the stored run configuration (name suggests N=3341 rows, p=8 features).
full_config_path = "abalone_N3341_p8"

relu_fit = get_model_fits(
    config=full_config_path,
    results_dir=results_dir_relu,
    models=model_names_relu,
    include_prior=False,
)
tanh_fit = get_model_fits(
    config=full_config_path,
    results_dir=results_dir_tanh,
    models=model_names_tanh,
    include_prior=False,
)

## Test the new global (posterior-based) pruning scheme

In [7]:
import numpy as np

def build_global_mask_from_posterior(
    W_samples,
    sparsity,
    method="Eabs",          # "Eabs" or "Eabs_stability"
    stability_quantile=0.1, # used if method="Eabs_stability"
    prune_smallest=True
):
    """Build one global 0/1 pruning mask for a weight tensor from its posterior draws.

    Every entry gets an importance score computed across draws; exactly
    ``floor(sparsity * num_params)`` entries are then pruned (set to 0).

    Parameters
    ----------
    W_samples : array-like, shape (S, ...)
        Posterior draws of a weight tensor; axis 0 indexes draws.
    sparsity : float in [0, 1)
        Fraction ``q`` of entries to prune; ``1 - q`` are kept.
    method : {"Eabs", "Eabs_stability"}
        ``"Eabs"``: score = E|w| over draws.
        ``"Eabs_stability"``: score = E|w| * P(|w| > t), where ``t`` is the
        ``stability_quantile`` quantile of |w| pooled over all draws and entries.
    stability_quantile : float
        Quantile used for ``t`` when ``method="Eabs_stability"``.
    prune_smallest : bool
        Prune the lowest scores (default) or, if False, the highest scores.

    Returns
    -------
    np.ndarray of float, same shape as one draw
        1.0 for kept entries, 0.0 for pruned entries.

    Raises
    ------
    ValueError
        If ``method`` is not recognised.
    """
    assert 0.0 <= sparsity < 1.0
    W_samples = np.asarray(W_samples)
    S = W_samples.shape[0]
    W_abs = np.abs(W_samples)  # (S, ...)

    # Importance score a = E|w|
    a = W_abs.mean(axis=0)

    if method == "Eabs":
        score = a
    elif method == "Eabs_stability":
        # Stability proxy pi = P(|w| > t), where t is a small global quantile of |w|
        t = np.quantile(W_abs.reshape(S, -1), stability_quantile)
        pi = (W_abs > t).mean(axis=0)
        # Combine: emphasize both "large on average" and "consistently non-tiny"
        score = a * pi
    else:
        raise ValueError("method must be 'Eabs' or 'Eabs_stability'")

    num_params = score.size
    k_prune = int(np.floor(sparsity * num_params))
    if k_prune == 0:
        return np.ones_like(score, dtype=float)

    flat = score.reshape(-1)
    # Rank entries in the order they should be pruned. The primary key is the
    # score (ascending to prune smallest, descending to prune largest). Ties are
    # broken by descending flat index, so among tied entries the lower indices
    # are kept. Pruning exactly the first k_prune ranked entries guarantees the
    # requested count even when many scores are equal.
    primary_key = flat if prune_smallest else -flat
    prune_order = np.lexsort((-np.arange(num_params), primary_key))

    mask_flat = np.ones(num_params, dtype=float)
    mask_flat[prune_order[:k_prune]] = 0.0
    return mask_flat.reshape(score.shape)


def precompute_global_masks(
    all_fits,
    model,
    sparsity_levels,
    prune_W2=False,
    method="Eabs_stability",
    stability_quantile=0.1,
):
    """Build global pruning masks for one model at several sparsity levels.

    Parameters
    ----------
    all_fits : dict
        ``all_fits[model]["posterior"]`` must expose ``stan_variable(name)``
        returning posterior draws of ``"W_1"`` (and ``"W_L"`` if ``prune_W2``).
    model : str
        Key of the model in ``all_fits``.
    sparsity_levels : iterable of float
        Pruning fractions to build masks for.
    prune_W2 : bool
        Also build masks for the output-layer weights ``W_L``.
    method : str
        Scoring rule forwarded to ``build_global_mask_from_posterior``.
    stability_quantile : float
        Forwarded to ``build_global_mask_from_posterior``; only used when
        ``method="Eabs_stability"``.

    Returns
    -------
    dict
        ``sparsity -> (mask_W1, mask_W2 or None)``.

    Raises
    ------
    KeyError
        If ``model`` or its ``"posterior"`` entry is missing from ``all_fits``.
    """
    fit = all_fits[model]["posterior"]

    W1_samples = fit.stan_variable("W_1")
    # Output-layer draws are only extracted when they will actually be pruned.
    W2_samples = fit.stan_variable("W_L") if prune_W2 else None

    masks = {}
    for q in sparsity_levels:
        mask_W1 = build_global_mask_from_posterior(
            W1_samples, q, method=method, stability_quantile=stability_quantile
        )
        mask_W2 = None
        if W2_samples is not None:
            mask_W2 = build_global_mask_from_posterior(
                W2_samples, q, method=method, stability_quantile=stability_quantile
            )
        masks[q] = (mask_W1, mask_W2)
    return masks


In [33]:
from utils.generate_data import load_abalone_regression_data
def _logsumexp(a, axis=None):
    amax = np.max(a, axis=axis, keepdims=True)
    out = amax + np.log(np.sum(np.exp(a - amax), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)

def gaussian_nll_pointwise(y, mu, sigma):
    """Elementwise negative log-density of ``y`` under Normal(mu, sigma**2).

    Inputs broadcast against each other; ``sigma`` is a standard deviation.
    """
    variance = sigma ** 2
    residual = y - mu
    log_norm = 0.5 * np.log(2 * np.pi * variance)
    return log_norm + 0.5 * residual ** 2 / variance

def compute_sparse_metrics_results_globalmask_large_eval(
    models, all_fits, forward_pass,
    sparsity=0.0,
    masks_cache=None,
    prune_W2=False,
    compute_nll=True,
    noise_var_name="sigma",
    # Unused; kept only so existing call sites keep working.
    gen_uncorr=None,
    gen_corr=None,
):
    """Evaluate globally pruned posterior networks on the abalone test split.

    For every model, each posterior draw of the single-hidden-layer network is
    optionally masked with the precomputed global mask for ``sparsity``. It is then
    pushed through ``forward_pass`` on the test inputs returned by
    ``load_abalone_regression_data(standardized=False, frac=1.0)``.

    Parameters
    ----------
    models : iterable of str
        Keys into ``all_fits``.
    all_fits : dict
        ``all_fits[model]["posterior"]`` must expose ``stan_variable(name)`` for
        ``"W_1"``, ``"W_L"``, ``"hidden_bias"``, ``"output_bias"`` and, when
        ``compute_nll``, ``noise_var_name``.
    forward_pass : callable
        ``forward_pass(X, W1, b1, W2, b2)`` returning predictions for ``X``.
    sparsity : float
        Pruning level; masks are applied only when ``sparsity > 0``.
    masks_cache : dict or None
        ``masks_cache[model][sparsity] -> (mask_W1, mask_W2)``. A model with no
        mask at this level is skipped with a message.
    prune_W2 : bool
        Also apply ``mask_W2`` to the output-layer weights.
    compute_nll : bool
        Add expected / predictive Gaussian NLL columns.
    noise_var_name : str
        Name of the observation-noise scale variable. If it cannot be read, a unit
        scale is used and a warning is printed.
    gen_uncorr, gen_corr :
        Unused.

    Returns
    -------
    (pd.DataFrame, pd.DataFrame)
        Per-draw metrics (one row per model and draw) and per-model
        posterior-mean metrics.
    """
    results = []
    posterior_means = []

    # Loaded once per call and shared by every model.
    # NOTE(review): the split is loaded with ``standardized=False`` but the ``*_orig``
    # columns below rescale by ``y_std`` as if predictions and targets were on a
    # standardized scale — confirm which scale the fits were trained on.
    X_train, X_test, y_train, y_test = load_abalone_regression_data(standardized=False, frac=1.0)
    y_std = y_train.std()
    # Flatten so a column-vector target cannot silently broadcast to (n, n).
    y_eval = np.asarray(y_test, dtype=float).reshape(-1)
    n_train = X_train.shape[0]
    n_eval = y_eval.shape[0]

    for model in models:
        try:
            fit = all_fits[model]['posterior']
            W1_samples = fit.stan_variable("W_1")           # (S, P, H)
            W2_samples = fit.stan_variable("W_L")           # (S, H, O)
            b1_samples = fit.stan_variable("hidden_bias")   # (S, O, H)
            b2_samples = fit.stan_variable("output_bias")   # (S, O)
        except KeyError:
            print(f"[SKIP] Model or posterior not found: -> {model}")
            continue

        mask_W1 = mask_W2 = None
        if masks_cache is not None and sparsity > 0.0:
            model_masks = masks_cache.get(model, {})
            if sparsity not in model_masks:
                print(f"[SKIP] No precomputed mask at sparsity={sparsity}: -> {model}")
                continue
            mask_W1, mask_W2 = model_masks[sparsity]

        S = W1_samples.shape[0]
        y_hats = np.zeros((S, n_eval))
        for i in range(S):
            W1 = W1_samples[i] if mask_W1 is None else W1_samples[i] * mask_W1
            W2 = W2_samples[i]
            if prune_W2 and mask_W2 is not None:
                W2 = W2 * mask_W2
            # hidden_bias draws are indexed [0] on their first per-draw axis
            # (assumed (O, H) — TODO confirm).
            y_hats[i] = forward_pass(X_test, W1, b1_samples[i][0], W2, b2_samples[i]).squeeze()

        # Per-draw RMSE and RMSE of the posterior-mean prediction, on the scale of y_test.
        rmses = np.sqrt(np.mean((y_hats - y_eval) ** 2, axis=1))
        posterior_mean = y_hats.mean(axis=0)
        posterior_mean_rmse = np.sqrt(np.mean((posterior_mean - y_eval) ** 2))

        common = {'N': n_train, 'model': model, 'sparsity': sparsity, 'n_eval': n_eval}
        out_pm = {
            **common,
            'posterior_mean_rmse': posterior_mean_rmse,
            'posterior_mean_rmse_orig': posterior_mean_rmse * y_std,
        }

        nll_per_draw = None
        if compute_nll:
            try:
                sig_s = np.asarray(fit.stan_variable(noise_var_name)).reshape(-1)[:S]
            except Exception as exc:
                print(f"[WARN] Could not read '{noise_var_name}' for {model} ({exc!r}); using sigma = 1.")
                sig_s = np.ones(S)

            # Pointwise NLL for every (draw, test point), computed once: (S, n_eval).
            nll = gaussian_nll_pointwise(y_eval[None, :], y_hats, sig_s[:, None])
            nll_per_draw = nll.mean(axis=1)

            # Expected NLL: average over draws of the per-draw mean NLL.
            out_pm["expected_nll"] = nll_per_draw.mean()

            # Predictive (mixture) NLL: -mean_n log((1/S) * sum_s p(y_n | draw s)).
            lppd = (_logsumexp(-nll, axis=0) - np.log(S)).mean()
            predictive_nll = -lppd
            out_pm["predictive_nll"] = predictive_nll

            # If y had been standardized by y_std, each density carries a 1/y_std
            # Jacobian, so the original-scale NLL is the standardized NLL + log(y_std).
            out_pm["predictive_nll_orig"] = predictive_nll + np.log(y_std)

        posterior_means.append(out_pm)

        for i in range(S):
            row = {**common, 'rmse': rmses[i], 'rmse_orig': rmses[i] * y_std}
            if nll_per_draw is not None:
                row["nll"] = nll_per_draw[i]
            results.append(row)

    return pd.DataFrame(results), pd.DataFrame(posterior_means)


In [39]:
from utils.sparsity import forward_pass_relu, forward_pass_tanh, local_prune_weights

def build_masks_cache_for_all(
    all_fits,
    models,
    sparsity_levels,
    prune_W2=False,
    method="Eabs_stability"
):
    """Precompute global pruning masks for several models.

    Returns a dict ``model -> {sparsity -> (mask_W1, mask_W2 or None)}``. Models
    whose fit cannot be found (``KeyError``) are reported and left out.
    """
    cache = {}
    for name in models:
        try:
            per_level = precompute_global_masks(
                all_fits=all_fits,
                model=name,
                sparsity_levels=sparsity_levels,
                prune_W2=prune_W2,
                method=method
            )
        except KeyError:
            print(f"[SKIP MASKS] Missing fit for -> {name}")
            continue
        cache[name] = per_level
    return cache

# Fractions of W_1 entries to prune; 0.0 is the unpruned baseline.
sparsity_levels = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]

# Masks depend only on the posterior draws, so build them once per
# (model, sparsity) and reuse them in every evaluation below.
masks_relu = build_masks_cache_for_all(
    all_fits=relu_fit,
    models=model_names_relu,
    sparsity_levels=sparsity_levels,
    prune_W2=False,
)
masks_tanh = build_masks_cache_for_all(
    all_fits=tanh_fit,
    models=model_names_tanh,
    sparsity_levels=sparsity_levels,
    prune_W2=False,
)


In [48]:
# Sweep every sparsity level for both activations. Each dict maps
# sparsity -> DataFrame (per-draw metrics / per-model posterior-mean metrics).
df_rmse_relu, df_post_relu = {}, {}
df_rmse_tanh, df_post_tanh = {}, {}

for names, fits, fwd, masks, per_draw, per_model in (
    (model_names_relu, relu_fit, forward_pass_relu, masks_relu, df_rmse_relu, df_post_relu),
    (model_names_tanh, tanh_fit, forward_pass_tanh, masks_tanh, df_rmse_tanh, df_post_tanh),
):
    for q in sparsity_levels:
        per_draw[q], per_model[q] = compute_sparse_metrics_results_globalmask_large_eval(
            models=names,
            all_fits=fits,
            forward_pass=fwd,
            sparsity=q,
            masks_cache=masks,
            prune_W2=False,
            compute_nll=True,
            noise_var_name="sigma",
        )

In [49]:
# Stack the per-sparsity posterior-mean tables into one long frame per activation.
df_post_relu_full, df_post_tanh_full = (
    pd.concat(
        [frame.assign(sparsity=level) for level, frame in per_level.items()],
        ignore_index=True,
    )
    for per_level in (df_post_relu, df_post_tanh)
)

In [None]:
# Unpruned ReLU models ranked by posterior-mean RMSE (lowest first).
df_post_relu_full.loc[df_post_relu_full["sparsity"].eq(0.0)].sort_values("posterior_mean_rmse")

In [None]:
# Unpruned ReLU models ranked by predictive NLL (lowest first).
df_post_relu_full.loc[df_post_relu_full["sparsity"].eq(0.0)].sort_values("predictive_nll")

In [None]:
# Unpruned tanh models ranked by posterior-mean RMSE (lowest first).
df_post_tanh_full.loc[df_post_tanh_full["sparsity"].eq(0.0)].sort_values("posterior_mean_rmse")

In [None]:
# Unpruned tanh models ranked by predictive NLL (lowest first).
df_post_tanh_full.loc[df_post_tanh_full["sparsity"].eq(0.0)].sort_values("predictive_nll")

In [None]:
# Predictive NLL (lower is better) vs. sparsity for the ReLU networks.
fig, ax = plt.subplots()
for model_name, short_label in [
    ("Dirichlet Horseshoe", "DHS"),
    ("Beta Horseshoe", "BHS"),
    ("Dirichlet Student T", "DST"),
    ("Beta Student T", "BST"),
]:
    # x comes from the plotted rows themselves, so a skipped model/level cannot
    # cause a length mismatch or misaligned points.
    curve = df_post_relu_full[df_post_relu_full["model"] == model_name].sort_values("sparsity")
    ax.plot(curve["sparsity"], curve["predictive_nll"], label=short_label)
ax.set(xlabel="Sparsity", ylabel="Predictive NLL", title="ReLU: predictive NLL vs. sparsity")
ax.legend()
ax.grid(True)
plt.show()


In [None]:

# Posterior-mean RMSE vs. sparsity for the ReLU networks.
fig, ax = plt.subplots()
for model_name, short_label in [
    ("Dirichlet Horseshoe", "DHS"),
    ("Beta Horseshoe", "BHS"),
    ("Dirichlet Student T", "DST"),
    ("Beta Student T", "BST"),
]:
    # x comes from the plotted rows themselves, so a skipped model/level cannot
    # cause a length mismatch or misaligned points.
    curve = df_post_relu_full[df_post_relu_full["model"] == model_name].sort_values("sparsity")
    ax.plot(curve["sparsity"], curve["posterior_mean_rmse"], label=short_label)
ax.set(xlabel="Sparsity", ylabel="Posterior-mean RMSE", title="ReLU: posterior-mean RMSE vs. sparsity")
ax.legend()
ax.grid(True)
plt.show()


In [None]:
# Predictive NLL (lower is better) vs. sparsity for the tanh networks.
fig, ax = plt.subplots()
for model_name, short_label in [
    ("Dirichlet Horseshoe tanh", "DHS"),
    ("Beta Horseshoe tanh", "BHS"),
    ("Dirichlet Student T tanh", "DST"),
    ("Beta Student T tanh", "BST"),
]:
    # x comes from the plotted rows themselves, so a skipped model/level cannot
    # cause a length mismatch or misaligned points.
    curve = df_post_tanh_full[df_post_tanh_full["model"] == model_name].sort_values("sparsity")
    ax.plot(curve["sparsity"], curve["predictive_nll"], label=short_label)
ax.set(xlabel="Sparsity", ylabel="Predictive NLL", title="tanh: predictive NLL vs. sparsity")
ax.legend()
ax.grid(True)
plt.show()


In [None]:

# Posterior-mean RMSE vs. sparsity for the tanh networks.
fig, ax = plt.subplots()
for model_name, short_label in [
    ("Dirichlet Horseshoe tanh", "DHS"),
    ("Beta Horseshoe tanh", "BHS"),
    ("Dirichlet Student T tanh", "DST"),
    ("Beta Student T tanh", "BST"),
]:
    # x comes from the plotted rows themselves, so a skipped model/level cannot
    # cause a length mismatch or misaligned points.
    curve = df_post_tanh_full[df_post_tanh_full["model"] == model_name].sort_values("sparsity")
    ax.plot(curve["sparsity"], curve["posterior_mean_rmse"], label=short_label)
ax.set(xlabel="Sparsity", ylabel="Posterior-mean RMSE", title="tanh: posterior-mean RMSE vs. sparsity")
ax.legend()
ax.grid(True)
plt.show()
