In [None]:
import os
import numpy as np
from types import SimpleNamespace
from cmdstanpy import set_cmdstan_path

from utils.model_runner import run_regression_model
from utils.load_abalone import load_abalone_regression_data

# Set CmdStan path. The CMDSTAN environment variable takes precedence so the
# notebook runs on other machines without editing this cell; the literal below
# is only the fallback (the original author's local install).
set_cmdstan_path(
    os.environ.get("CMDSTAN", "/Users/augustarnstad/.cmdstan/cmdstan-2.36.0")
)

## Run model

In [2]:
# --- Experiment configuration (edit here, then re-run the cells below) ---
model_name = "dirichlet_horseshoe"  # e.g. "horseshoe", "dirichlet_horseshoe", ...
overwrite = False       # when True, the run cell refits even if its output directory exists
output_dir = "results"  # root folder; the run cell appends <model_name>/<config_name>
standardize = True      # forwarded to the data loader and recorded in the run arguments

# Data split / seed
frac, seed = 0.2, 42

# Model/run settings
# NOTE(review): H and L are passed through to the runner unchanged; presumably
# hidden width and number of hidden layers -- confirm against the model code.
H, L = 16, 1
burnin_samples = samples = 1000

In [None]:
# Load the abalone train/test split with the settings from the config cell.
X_train, X_test, y_train, y_test = load_abalone_regression_data(standardized=standardize, frac=frac)

# The training-set dimensions (rows, columns) are baked into the config name.
N, p = X_train.shape
data_type = "abalone"

# Outputs of this run go to <output_dir>/<model_name>/<config_name>.
config_name = f"abalone_N{N}_p{p}{'_standardized' if standardize else ''}"
model_output_dir = os.path.join(output_dir, model_name, config_name)

print("Config:", config_name)
print("Output:", model_output_dir)

In [None]:
# Fit the model unless this configuration's output directory is already on
# disk; setting `overwrite = True` in the config cell forces a refit.
# Note that only the directory's existence is checked, not its contents.
if overwrite or not os.path.exists(model_output_dir):
    print(f"[Run] Running model on: {config_name}")

    # Run settings bundled as attributes and handed to the runner with the data.
    args = SimpleNamespace(
        # data description
        N=N, p=p, sigma=None, data=data_type,
        standardize=standardize, test_shift=None,
        # model
        model=model_name, H=H, L=L,
        # bookkeeping
        config=config_name, seed=seed, data_config="realworld",
        model_output_dir=model_output_dir,
        # sampling
        burnin_samples=burnin_samples, samples=samples,
        overwrite=overwrite, output_dir=output_dir,
    )

    run_regression_model(
        model_name=model_name, config_name=config_name,
        X_train=X_train, X_test=X_test,
        y_train=y_train, y_test=y_test,
        args=args,
    )
else:
    print(f"[Skip] Already completed: {config_name}")

## Load models

In [None]:
from utils.model_loader import get_model_fits

data_dir = "datasets/abalone"
results_dir_tanh = "results/regression/single_layer/tanh/abalone"

# Names of the fitted tanh models to load; they become the keys of `tanh_fit`
# as used by the cells below.
model_names_tanh = [
    "Dirichlet Horseshoe tanh",
    "Dirichlet Student T tanh",
    "Beta Horseshoe tanh",
    "Beta Student T tanh",
]

# NOTE(review): this config name is hard-coded and carries no "_standardized"
# suffix, whereas the run cell derives its name from N, p and `standardize`.
# Confirm it points at the fits you intend to analyse.
full_config_path = "abalone_N3341_p8"

tanh_fit = get_model_fits(
    results_dir=results_dir_tanh,
    config=full_config_path,
    models=model_names_tanh,
    include_prior=False,
)

## Get results

In [None]:
from sklearn.metrics import mean_squared_error
from properscoring import crps_ensemble
import numpy as np
import pandas as pd

# Predictive accuracy of every loaded fit on the held-out `y_test` from the
# data-loading cell.
rows = []
# The loop variable is intentionally not called `model_name`: in a notebook it
# would overwrite the config-cell global of that name, and re-running the
# "Run model" cell afterwards would silently fit/skip the wrong model.
for fit_name, model_entry in tanh_fit.items():
    post = model_entry["posterior"]

    # Posterior predictive draws on the test set with the trailing singleton
    # axis removed; assumed to be (S, n_test) -- TODO confirm against the Stan model.
    y_samps = post.stan_variable("output_test").squeeze(-1)

    # RMSE of the posterior-mean prediction.
    y_mean = y_samps.mean(axis=0)                                   # (n_test,)
    rmse_post_mean = float(np.sqrt(mean_squared_error(y_test, y_mean)))

    # RMSE of each individual draw, averaged over draws.
    per_draw_rmse = np.sqrt(((y_samps - y_test[None, :])**2).mean(axis=1))  # (S,)
    rmse_draw_mean = float(per_draw_rmse.mean())

    # crps_ensemble expects the ensemble members on the last axis, hence the transpose.
    crps = float(np.mean(crps_ensemble(y_test, y_samps.T)))

    rows.append({
        "Model": fit_name,
        "RMSE_posterior_mean": rmse_post_mean,
        "RMSE_mean_over_draws": rmse_draw_mean,
        "CRPS": crps,
        "n_draws": y_samps.shape[0]
    })

results_df = pd.DataFrame(rows).sort_values("RMSE_posterior_mean")
print(results_df)


## Prune per sample

In [None]:
def compute_sparse_rmse_results_abalone(models, all_fits, forward_pass, frac,
                         sparsity=0.0, prune_fn=None):
    """Test RMSE per posterior draw and for the posterior mean, with optional per-draw pruning.

    Parameters
    ----------
    models : iterable
        Keys into ``all_fits``. Entries that are missing, or that have no
        ``'posterior'`` key, are reported and skipped.
    all_fits : mapping
        ``all_fits[model]['posterior']`` must expose ``stan_variable(name)`` for
        ``"W_1"``, ``"W_L"``, ``"hidden_bias"`` and ``"output_bias"``.
    forward_pass : callable
        Called as ``forward_pass(X_test, W1, b1, W2, b2)``; its output is squeezed
        to one prediction per test row.
    frac : float
        Forwarded to ``load_abalone_regression_data``.
    sparsity : float, default 0.0
        Forwarded to ``prune_fn``; pruning only happens when this is > 0.
    prune_fn : callable or None
        Called per draw as ``prune_fn([W1, W2], sparsity)`` and must return a
        sequence of masks. Only the first mask (for ``W1``) is applied; the
        second-layer weights are left unpruned.

    Returns
    -------
    (df_rmse, df_posterior_rmse) : tuple of pandas.DataFrame
        ``df_rmse`` has one row per (model, draw) with columns ``model``,
        ``sparsity``, ``rmse``. ``df_posterior_rmse`` has one row per model with
        columns ``model``, ``sparsity``, ``posterior_mean_rmse``.

    Notes
    -----
    NOTE(review): the data is loaded with ``standardized=False`` while the
    notebook's run configuration sets ``standardize = True`` -- confirm that the
    inputs match the scale the fits were trained on.
    """
    # Loop-invariant: load the evaluation data once instead of once per model.
    _, X_test, _, y_test = load_abalone_regression_data(standardized=False, frac=frac)
    # Flatten the targets once so a column-vector y_test (n, 1) cannot broadcast
    # against (n,) predictions into an (n, n) error matrix.
    y_true = np.asarray(y_test).reshape(-1)
    n_test = y_true.shape[0]

    apply_pruning = prune_fn is not None and sparsity > 0.0

    results = []
    posterior_means = []
    for model in models:
        try:
            fit = all_fits[model]['posterior']
            W1_samples = fit.stan_variable("W_1")           # (S, P, H)
            W2_samples = fit.stan_variable("W_L")           # (S, H, O)
            b1_samples = fit.stan_variable("hidden_bias")   # (S, O, H)
            b2_samples = fit.stan_variable("output_bias")   # (S, O)
        except KeyError:
            print(f"[SKIP] Model or posterior not found: -> {model}")
            continue

        S = W1_samples.shape[0]
        rmses = np.zeros(S)
        y_hats = np.zeros((S, n_test))

        for i in range(S):
            W1 = W1_samples[i]
            W2 = W2_samples[i]

            # Masks are recomputed for every draw; only the W1 mask is used.
            if apply_pruning:
                masks = prune_fn([W1, W2], sparsity)
                W1 = W1 * masks[0]

            y_hat = forward_pass(X_test, W1, b1_samples[i][0], W2, b2_samples[i]).squeeze()
            y_hats[i] = y_hat  # keep the draw's predictions for the posterior mean
            rmses[i] = np.sqrt(np.mean((y_hat - y_true)**2))

        posterior_mean = np.mean(y_hats, axis=0)
        posterior_mean_rmse = np.sqrt(np.mean((posterior_mean - y_true)**2))

        posterior_means.append({
            'model': model,
            'sparsity': sparsity,
            'posterior_mean_rmse': posterior_mean_rmse
        })

        results.extend(
            {'model': model, 'sparsity': sparsity, 'rmse': rmse}
            for rmse in rmses
        )

    df_rmse = pd.DataFrame(results)
    df_posterior_rmse = pd.DataFrame(posterior_means)

    return df_rmse, df_posterior_rmse


In [None]:
from utils.sparsity import forward_pass_tanh, local_prune_weights

# Fractions of weights to prune; shared by the per-sample and the global-mask experiments.
sparsity_levels = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95]

# sparsity level -> per-draw RMSE table / posterior-mean RMSE table
df_rmse_tanh = {}
df_posterior_rmse_tanh = {}

for sparsity in sparsity_levels:
    (
        df_rmse_tanh[sparsity],
        df_posterior_rmse_tanh[sparsity],
    ) = compute_sparse_rmse_results_abalone(
        models=model_names_tanh,
        all_fits=tanh_fit,
        forward_pass=forward_pass_tanh,
        frac=1.0,
        sparsity=sparsity,
        prune_fn=local_prune_weights,
    )


## Prune posterior

In [None]:
def build_global_mask_from_posterior(
    W_samples,
    sparsity,
    method="Eabs",          # "Eabs" or "Eabs_stability"
    stability_quantile=0.1, # used if method="Eabs_stability"
    prune_smallest=True
):
    """Build one 0/1 pruning mask for a weight array from all of its posterior draws.

    Parameters
    ----------
    W_samples : array-like, shape (S, ...)
        Posterior draws of a weight array; axis 0 indexes the draws.
    sparsity : float in [0, 1)
        Fraction q of entries to prune. Exactly ``floor(q * n)`` entries are set
        to 0 and the rest to 1, where ``n`` is the number of entries in one draw.
    method : {"Eabs", "Eabs_stability"}
        ``"Eabs"`` scores every entry by its posterior mean absolute value E|w|.
        ``"Eabs_stability"`` multiplies that score by the fraction of draws in
        which |w| exceeds the ``stability_quantile`` quantile of all |w| values.
    stability_quantile : float
        Quantile used by ``"Eabs_stability"``; ignored otherwise.
    prune_smallest : bool
        If True (default) the lowest-scoring entries are pruned, otherwise the
        highest-scoring ones.

    Returns
    -------
    numpy.ndarray of float
        Mask with the shape of a single draw, containing only 0.0 and 1.0.

    Raises
    ------
    AssertionError
        If ``sparsity`` is outside [0, 1).
    ValueError
        If ``method`` is not one of the two supported names.

    Notes
    -----
    Ties at the cut-off score are resolved deterministically: the tied entries
    with the lowest flat (row-major) index are kept until the target count is
    reached. This now holds for both pruning directions; previously
    ``prune_smallest=False`` could keep too few entries when scores were tied.
    """
    assert 0.0 <= sparsity < 1.0
    W_abs = np.abs(np.asarray(W_samples))  # (S, ...)
    S = W_abs.shape[0]

    # Importance score a = E|w|
    a = W_abs.mean(axis=0)     # (..., ...)

    if method == "Eabs":
        score = a
    elif method == "Eabs_stability":
        # Stability proxy pi = P(|w| > t), where t is a small global quantile of |w|
        t = np.quantile(W_abs.reshape(S, -1), stability_quantile)
        pi = (W_abs > t).mean(axis=0)
        # Combine: emphasize both "large on average" and "consistently non-tiny"
        score = a * pi
    else:
        raise ValueError("method must be 'Eabs' or 'Eabs_stability'")

    # Decide how many to prune
    num_params = score.size
    k_prune = int(np.floor(sparsity * num_params))
    if k_prune == 0:
        return np.ones_like(score, dtype=float)

    n_keep = num_params - k_prune
    flat = score.reshape(-1)

    # Keep everything strictly on the "good" side of the cut-off score. Because
    # the cut-off is an order statistic, this can never keep more than n_keep
    # entries; it keeps fewer only when several scores tie at the cut-off.
    if prune_smallest:
        thresh = np.partition(flat, k_prune - 1)[k_prune - 1]   # largest pruned score
        keep = flat > thresh
    else:
        thresh = np.partition(flat, n_keep)[n_keep]             # smallest pruned score
        keep = flat < thresh

    # Top up from the tied entries (lowest flat index first) to hit n_keep exactly.
    shortfall = n_keep - int(keep.sum())
    if shortfall > 0:
        tied = np.flatnonzero(flat == thresh)
        keep[tied[:shortfall]] = True

    return keep.reshape(score.shape).astype(float)

def precompute_global_masks(
    all_fits,
    model,
    sparsity_levels,
    prune_W2=False,
    method="Eabs_stability"
):
    """Build the global pruning masks of one fitted model for every sparsity level.

    Reads the posterior draws of ``"W_1"`` and ``"W_L"`` from
    ``all_fits[model]["posterior"]`` and returns a dict mapping each sparsity
    level to ``(mask_W1, mask_W2)``. ``mask_W2`` is ``None`` unless ``prune_W2``
    is true. A ``KeyError`` from the ``all_fits`` lookup is not caught here.
    """
    posterior = all_fits[model]["posterior"]
    first_layer_draws = posterior.stan_variable("W_1")  # (S, P, H)
    last_layer_draws = posterior.stan_variable("W_L")   # (S, H, O) or (S, H) depending on O

    def _mask_pair(level):
        # First-layer mask always; last-layer mask only on request.
        w1_mask = build_global_mask_from_posterior(first_layer_draws, level, method=method)
        if not prune_W2:
            return (w1_mask, None)
        w2_mask = build_global_mask_from_posterior(last_layer_draws, level, method=method)
        return (w1_mask, w2_mask)

    return {level: _mask_pair(level) for level in sparsity_levels}


In [None]:
def _logsumexp(a, axis=None):
    amax = np.max(a, axis=axis, keepdims=True)
    out = amax + np.log(np.sum(np.exp(a - amax), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)

def gaussian_nll_pointwise(y, mu, sigma):
    """Elementwise negative log-density of ``y`` under a Normal(``mu``, ``sigma**2``).

    Inputs follow the usual NumPy broadcasting rules; no reduction is applied.
    """
    variance = sigma ** 2
    squared_error = (y - mu) ** 2
    log_normaliser = 0.5 * np.log(2 * np.pi * variance)
    return log_normaliser + 0.5 * squared_error / variance

def compute_sparse_metrics_results_globalmask(
    models, all_fits, forward_pass,
    sparsity=0.0,
    masks_cache=None,
    prune_W2=False,
    compute_nll=True,
    noise_var_name="sigma",
    frac = 1.0
):
    """Posterior-mean RMSE (and optionally Gaussian NLLs) under one fixed mask per model.

    Unlike per-draw pruning, the same precomputed mask is applied to every
    posterior draw of a model.

    Parameters
    ----------
    models : iterable
        Keys into ``all_fits``. Entries that are missing, or that have no
        ``'posterior'`` key, are reported and skipped.
    all_fits : mapping
        ``all_fits[model]['posterior']`` must expose ``stan_variable(name)``.
    forward_pass : callable
        Called as ``forward_pass(X_test, W1, b1, W2, b2)``.
    sparsity : float
        Key into ``masks_cache[model]``. Masks are only applied when this is > 0.
    masks_cache : mapping or None
        ``masks_cache[model][sparsity] -> (mask_W1, mask_W2)``. ``None`` disables pruning.
    prune_W2 : bool
        Apply ``mask_W2`` (when it is not ``None``) to the last-layer weights.
    compute_nll : bool
        Also report ``expected_nll`` (mean over draws of the mean pointwise NLL)
        and ``predictive_nll`` (negative mean log of the draw-averaged density).
    noise_var_name : str
        Stan variable holding the per-draw noise scale. If it cannot be read, a
        warning is printed and sigma = 1 is used for every draw.
    frac : float
        Forwarded to ``load_abalone_regression_data``.

    Returns
    -------
    pandas.DataFrame
        One row per evaluated model with columns ``N``, ``model``, ``sparsity``,
        ``n_eval``, ``posterior_mean_rmse`` and, when ``compute_nll`` is true,
        ``expected_nll`` and ``predictive_nll``.

    Notes
    -----
    NOTE(review): the data is loaded with ``standardized=False`` while the
    notebook's run configuration sets ``standardize = True`` -- confirm that the
    inputs and targets are on the scale the fits were trained on.
    """
    posterior_means = []
    X_train, X_test, y_train, y_test = load_abalone_regression_data(standardized=False, frac=frac)
    # Flatten once so a column-vector target cannot broadcast to (n, n) below.
    y_true = np.asarray(y_test).reshape(-1)
    n_eval = y_true.shape[0]

    for model in models:
        try:
            fit = all_fits[model]['posterior']
            W1_samples = fit.stan_variable("W_1")           # (S, P, H)
            W2_samples = fit.stan_variable("W_L")           # (S, H, O)
            b1_samples = fit.stan_variable("hidden_bias")   # (S, O, H)
            b2_samples = fit.stan_variable("output_bias")   # (S, O)
        except KeyError:
            print(f"[SKIP] Model or posterior not found: -> {model}")
            continue

        # Best effort: a fit without the noise variable is still evaluated, but
        # the fallback is announced because it changes the meaning of the NLLs.
        noise_samples = None
        if compute_nll:
            try:
                noise_samples = np.asarray(fit.stan_variable(noise_var_name)).reshape(-1)
            except Exception as exc:
                print(f"[WARN] Could not read '{noise_var_name}' for {model} "
                      f"({exc!r}); using sigma = 1 for the NLL metrics")

        S = W1_samples.shape[0]
        y_hats = np.zeros((S, n_eval))

        mask_W1 = mask_W2 = None
        if masks_cache is not None and sparsity > 0.0:
            mask_W1, mask_W2 = masks_cache[model][sparsity]

        for i in range(S):
            W1 = W1_samples[i]
            W2 = W2_samples[i]

            if mask_W1 is not None:
                W1 = W1 * mask_W1
            if prune_W2 and (mask_W2 is not None):
                W2 = W2 * mask_W2

            y_hats[i] = forward_pass(X_test, W1, b1_samples[i][0], W2, b2_samples[i]).squeeze()

        # RMSE of the posterior-mean prediction, on the scale of the loaded targets.
        posterior_mean = y_hats.mean(axis=0)
        posterior_mean_rmse = np.sqrt(np.mean((posterior_mean - y_true)**2))

        out_pm = {
            'N': X_train.shape[0],
            'model': model,
            'sparsity': sparsity,
            'n_eval': n_eval,
            'posterior_mean_rmse': posterior_mean_rmse,
        }

        if compute_nll:
            sig_s = np.ones(S) if noise_samples is None else noise_samples[:S]

            # Pointwise NLL of every draw, computed once and shared by both metrics.
            nll = np.stack([
                gaussian_nll_pointwise(y_true, y_hats[i], sig_s[i])
                for i in range(S)
            ], axis=0)  # (S, n_eval)

            # Expected NLL: average over points, then over draws.
            out_pm["expected_nll"] = nll.mean(axis=1).mean()

            # Predictive (mixture) NLL: log of the draw-averaged density per point.
            lppd = (_logsumexp(-nll, axis=0) - np.log(S)).mean()
            out_pm["predictive_nll"] = -lppd

        posterior_means.append(out_pm)

    return pd.DataFrame(posterior_means)


In [None]:
def build_masks_cache_for_all(
    all_fits,
    models,
    sparsity_levels,
    prune_W2=False,
    method="Eabs_stability"
):
    """Precompute the global pruning masks of every available model.

    Returns ``{model: precompute_global_masks(...)}``. A model whose lookup
    raises ``KeyError`` is reported and left out of the result.
    """
    cache = {}
    for name in models:
        try:
            per_level = precompute_global_masks(
                all_fits=all_fits,
                model=name,
                sparsity_levels=sparsity_levels,
                prune_W2=prune_W2,
                method=method,
            )
        except KeyError:
            print(f"[SKIP MASKS] Missing fit for -> {name}")
            continue
        cache[name] = per_level
    return cache


# One global W1 mask per (model, sparsity level); last-layer weights stay unpruned.
masks_tanh = build_masks_cache_for_all(
    all_fits=tanh_fit,
    models=model_names_tanh,
    sparsity_levels=sparsity_levels,
    prune_W2=False,
)


In [None]:
# sparsity level -> metrics table of the globally masked tanh models
df_post_tanh = {}

for q in sparsity_levels:
    df_post_tanh[q] = compute_sparse_metrics_results_globalmask(
        model_names_tanh,
        tanh_fit,
        forward_pass_tanh,
        sparsity=q,
        masks_cache=masks_tanh,
        prune_W2=False,
        compute_nll=True,
        noise_var_name="sigma",
        frac=1.0,
    )