In [None]:
import sys, os; sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__) if '__file__' in globals() else os.getcwd(), '..')))
#import os; os.chdir(os.path.dirname(os.getcwd()))
from utils.model_loader import get_model_fits
import numpy as np
import pandas as pd
import re
from sklearn.metrics import mean_squared_error
import seaborn as sns
import matplotlib.pyplot as plt


In [None]:
data_dir = "datasets/abalone"
results_dir_tanh = "results/regression/single_layer/tanh/abalone"
model_names_tanh_nodewise = [
    "Gaussian tanh",
    "Regularized Horseshoe tanh",
    "Dirichlet Horseshoe tanh nodewise",
    "Dirichlet Student T tanh nodewise",
]

# One config per training-set size (per the names: N = 334 / 668 / 3341 rows, p = 8 features).
config_path_small = "abalone_N334_p8"
config_path_medium = "abalone_N668_p8"
config_path_full = "abalone_N3341_p8"


def _load_tanh_fits(config_path):
    """Load the posterior fits (no prior fits) of all nodewise tanh models for one config."""
    return get_model_fits(
        config=config_path,
        results_dir=results_dir_tanh,
        models=model_names_tanh_nodewise,
        include_prior=False,
    )


tanh_fit_nodewise_small = _load_tanh_fits(config_path_small)
tanh_fit_nodewise_medium = _load_tanh_fits(config_path_medium)
tanh_fit_nodewise_full = _load_tanh_fits(config_path_full)

## CRPS (continuous ranked probability score)

In [25]:
import numpy as np
from properscoring import crps_ensemble
from utils.generate_data import load_abalone_regression_data
# Test splits for the three data fractions (0.1, 0.2, 1.0) on the unstandardized scale.
# Each split gets its own name; previously all three X_test values were bound to
# X_test_small, which silently ended up holding the frac=1.0 split.
_, X_test_small, _, y_test_small = load_abalone_regression_data(standardized=False, frac=0.1)
_, X_test_medium, _, y_test_medium = load_abalone_regression_data(standardized=False, frac=0.2)
_, X_test_full, _, y_test_full = load_abalone_regression_data(standardized=False, frac=1.0)

def random_subset(y, n_sub=None, seed=1):
    """Return ``n_sub`` distinct random positions into ``y`` (all of them, shuffled, by default).

    The generator is seeded, so the same (len(y), n_sub, seed) always yields the same indices.
    """
    if n_sub is None:
        n_sub = y.shape[0]
    generator = np.random.default_rng(seed)
    return generator.choice(len(y), size=n_sub, replace=False)

def compute_crps(models, all_fits, y_test):
    """Per-observation CRPS of each model's posterior predictive on the test set.

    Parameters
    ----------
    models : iterable of str
        Model names to look up in ``all_fits``; missing ones are reported and skipped.
    all_fits : dict
        ``{model: {'posterior': fit}}``; ``fit.stan_variable("output_test_rng")``
        must give predictive draws with draws on axis 0 and test points on axis 1
        (a trailing size-1 output axis, if present, is dropped).
    y_test : array-like
        Observed test targets, aligned positionally with the predictive draws.

    Returns
    -------
    dict
        ``{model: {'crps': (n,), 'y_test': (n,), 'preds': (n, S)}}`` with test
        points in the (seeded) order produced by ``random_subset``.
    """
    results = {}
    # Positional 1-D numpy view: indexing a pandas Series with an int array would
    # otherwise be label-based and could pick the wrong rows (or raise).
    y_arr = np.asarray(y_test).reshape(-1)
    idx = random_subset(y_arr)
    y_sub = y_arr[idx]

    for model in models:
        try:
            fit = all_fits[model]['posterior']
            preds_raw = fit.stan_variable("output_test_rng")
        except KeyError:
            print(f"[SKIP] Model or posterior not found: -> {model}")
            continue

        preds = np.asarray(preds_raw)
        if preds.ndim > 2:
            # (S, n, 1) -> (S, n); a plain (S, n) array is used as-is.
            preds = preds.squeeze(-1)
        preds_sub = preds[:, idx].T  # (n, S): one ensemble per observation

        results[model] = {
            'crps': crps_ensemble(y_sub, preds_sub),
            'y_test': y_sub,
            'preds': preds_sub,
        }
    return results
   
res_small, res_medium, res_full = (
    compute_crps(model_names_tanh_nodewise, fits, y_observed)
    for fits, y_observed in (
        (tanh_fit_nodewise_small, y_test_small),
        (tanh_fit_nodewise_medium, y_test_medium),
        (tanh_fit_nodewise_full, y_test_full),
    )
)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

# Short model labels used on the x-axis of the CRPS box plot.
abbr = dict([
    ("Gaussian tanh", "Gauss"),
    ("Regularized Horseshoe tanh", "RHS"),
    ("Dirichlet Horseshoe tanh nodewise", "DHS"),
    ("Dirichlet Student T tanh nodewise", "DST"),
])

def plot_grouped_box_by_model(model_names, res_small, res_medium, res_full,
                              sizes=("0.1N", "0.2N", "N"),
                              figsize=(7, 2.8),
                              save_path="figures_for_use_in_paper/abalone_crps.pdf"):
    """Grouped CRPS box plots: one group per model, one box per training-set size.

    Parameters
    ----------
    model_names : list of str
        Models to plot, in x-axis order. Tick labels use the module-level
        ``abbr`` mapping when it has an entry, else the full name.
    res_small, res_medium, res_full : dict
        ``{model: {'crps': array, ...}}`` for the three training-set sizes.
    sizes : tuple of str
        Legend labels for the three sizes, in the same order.
    figsize : tuple
        Figure size in inches.
    save_path : str or None
        Output file; its directory is created if needed. ``None`` skips saving.
    """
    labels = [abbr.get(m, m) for m in model_names]
    res_list = [res_small, res_medium, res_full]

    x = np.arange(len(model_names)) + 1
    offsets = np.array([-0.22, 0.00, 0.22])
    width = 0.16
    size_colors = ["C0", "C1", "C2"]
    box_alpha = 0.55

    fig, ax = plt.subplots(figsize=figsize)

    for offset, res_j, col in zip(offsets, res_list, size_colors):
        data_j = [res_j[m]["crps"] for m in model_names]
        bp = ax.boxplot(
            data_j,
            positions=x + offset,
            widths=width,
            patch_artist=True,
            showfliers=False,
        )
        for box in bp["boxes"]:
            box.set_facecolor(col)
            box.set_alpha(box_alpha)
            box.set_edgecolor("black")
        for k in ("whiskers", "caps", "medians"):
            for line in bp[k]:
                line.set_color("black")
                line.set_linewidth(1.0)

    # Set after boxplot(), which otherwise places ticks at the box positions.
    ax.set_xticks(x)
    ax.set_xticklabels(labels)
    ax.set_ylabel("CRPS")
    ax.grid(axis="y", alpha=0.3)

    # Legend derived from the same colours/labels as the boxes, so they cannot drift apart.
    legend_handles = [
        Patch(facecolor=col, edgecolor="black", alpha=box_alpha, label=label)
        for col, label in zip(size_colors, sizes)
    ]
    ax.legend(handles=legend_handles, title="", frameon=False, loc="upper center")

    fig.tight_layout()
    if save_path:
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, bbox_inches="tight")
    plt.show()


# Default legend labels ("0.1N", "0.2N", "N") match the three data fractions.
plot_grouped_box_by_model(
    model_names_tanh_nodewise,
    res_small,
    res_medium,
    res_full,
    figsize=(4, 3),
)


## RMSE and NLL

In [27]:
def _logsumexp(a, axis=None):
    amax = np.max(a, axis=axis, keepdims=True)
    out = amax + np.log(np.sum(np.exp(a - amax), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)

def gaussian_nll_pointwise(y, mu, sigma):
    """Elementwise negative log-density of ``y`` under Normal(mu, sigma**2); broadcasts."""
    variance = sigma**2
    squared_residual = (y - mu)**2
    return 0.5 * np.log(2 * np.pi * variance) + 0.5 * squared_residual / variance

def compute_rmse(
    models, all_fits, forward_pass,
    compute_nll=True,
    noise_var_name="sigma",
    frac = 1.0
):
    """Posterior-mean RMSE and Gaussian NLL of each model on the abalone test split.

    Every posterior draw of the network is pushed through ``forward_pass`` on the
    test inputs from ``load_abalone_regression_data(standardized=False, frac=frac)``,
    so all metrics are on the unstandardized target scale.

    Parameters
    ----------
    models : iterable of str
        Model names to look up in ``all_fits``; missing ones are reported and skipped.
    all_fits : dict
        ``{model: {'posterior': fit}}``; the fit must expose W_1, W_L,
        hidden_bias and output_bias via ``stan_variable``.
    forward_pass : callable
        ``forward_pass(X, W1, b1, W2, b2)`` returning predictions for X.
    compute_nll : bool
        Also compute expected and predictive (mixture) Gaussian NLL.
    noise_var_name : str
        Name of the observation-noise scale in the fit. If it cannot be read,
        sigma = 1 is used and a warning is printed.
    frac : float
        Data fraction passed to the loader.

    Returns
    -------
    pandas.DataFrame
        One row per model: N (training rows), model, n_eval, posterior_mean_rmse
        and, if ``compute_nll``, expected_nll and predictive_nll (averaged over
        test points).
    """
    posterior_means = []
    X_train, X_test, y_train, y_test = load_abalone_regression_data(standardized=False, frac=frac)
    # Positional 1-D targets: an (n, 1) y_test would otherwise broadcast against
    # (n,) predictions into an (n, n) error matrix.
    y_eval = np.asarray(y_test, dtype=float).reshape(-1)

    for model in models:
        try:
            fit = all_fits[model]['posterior']
            W1_samples = fit.stan_variable("W_1")           # (S, P, H)
            W2_samples = fit.stan_variable("W_L")           # (S, H, O)
            b1_samples = fit.stan_variable("hidden_bias")   # (S, O, H)
            b2_samples = fit.stan_variable("output_bias")   # (S, O)

            noise_samples = None
            if compute_nll:
                try:
                    noise_samples = fit.stan_variable(noise_var_name).squeeze()
                except Exception:
                    noise_samples = None
        except KeyError:
            print(f"[SKIP] Model or posterior not found: -> {model}")
            continue

        S = W1_samples.shape[0]
        y_hats = np.zeros((S, y_eval.shape[0]))
        for i in range(S):
            y_hats[i] = forward_pass(X_test, W1_samples[i], b1_samples[i][0],
                                     W2_samples[i], b2_samples[i]).squeeze()

        # RMSE of the posterior-mean prediction (unstandardized scale)
        posterior_mean = y_hats.mean(axis=0)
        posterior_mean_rmse = np.sqrt(np.mean((posterior_mean - y_eval) ** 2))

        out_pm = {
            'N': X_train.shape[0],
            'model': model,
            'n_eval': y_eval.shape[0],
            'posterior_mean_rmse': posterior_mean_rmse,
        }

        if compute_nll:
            if noise_samples is None:
                print(f"[WARN] '{noise_var_name}' not available for {model}; NLL uses sigma = 1")
                sig_s = np.ones(S)
            else:
                sig_s = np.asarray(noise_samples).reshape(-1)[:S]

            # (S, n_eval) pointwise NLL; one sigma per draw broadcast over test points.
            nll = gaussian_nll_pointwise(y_eval[None, :], y_hats, sig_s[:, None])

            # Expected NLL: average of per-draw NLL (equal row lengths -> plain mean).
            out_pm["expected_nll"] = nll.mean()

            # Predictive NLL: -mean_n log( (1/S) sum_s p(y_n | draw s) ).
            lppd = (_logsumexp(-nll, axis=0) - np.log(S)).mean()
            out_pm["predictive_nll"] = -lppd

        posterior_means.append(out_pm)

    return pd.DataFrame(posterior_means)


In [28]:
from utils.sparsity import forward_pass_tanh
def _rmse_nll_for(fits, frac):
    """RMSE / NLL table for the nodewise tanh models fitted on one data fraction."""
    return compute_rmse(
        models=model_names_tanh_nodewise,
        all_fits=fits,
        forward_pass=forward_pass_tanh,
        compute_nll=True,
        noise_var_name="sigma",
        frac=frac,
    )


rmse_nll_df_small = _rmse_nll_for(tanh_fit_nodewise_small, 0.1)
rmse_nll_df_medium = _rmse_nll_for(tanh_fit_nodewise_medium, 0.2)
rmse_nll_df_full = _rmse_nll_for(tanh_fit_nodewise_full, 1.0)

In [None]:
rmse_nll_df_full.round(decimals=3)

## Sparsity: per-draw local pruning

In [52]:
from utils.generate_data import load_abalone_regression_data
def compute_sparse_rmse_results_abalone(models, all_fits, forward_pass, frac,
                         sparsity=0.0, prune_fn=None):
    """Per-draw and posterior-mean RMSE after per-draw pruning of W_1.

    For every posterior draw, if ``prune_fn`` is given and ``sparsity > 0``,
    ``prune_fn([W1, W2], sparsity)`` is called and only its first mask is applied
    (to W1; W2 is left unpruned). Evaluation uses the test split from
    ``load_abalone_regression_data(standardized=False, frac=frac)``.

    Returns
    -------
    (df_rmse, df_posterior_rmse)
        df_rmse: one row per (model, draw) with model, sparsity, rmse.
        df_posterior_rmse: one row per model with model, sparsity,
        posterior_mean_rmse (RMSE of the mean prediction over draws).
    """
    results = []
    posterior_means = []
    # The test split does not depend on the model, so load it once.
    _, X_test, _, y_test = load_abalone_regression_data(standardized=False, frac=frac)
    # Positional 1-D targets so (n,) predictions never broadcast against (n, 1).
    y_eval = np.asarray(y_test, dtype=float).reshape(-1)
    apply_pruning = prune_fn is not None and sparsity > 0.0

    for model in models:
        try:
            fit = all_fits[model]['posterior']
            W1_samples = fit.stan_variable("W_1")           # (S, P, H)
            W2_samples = fit.stan_variable("W_L")           # (S, H, O)
            b1_samples = fit.stan_variable("hidden_bias")   # (S, O, H)
            b2_samples = fit.stan_variable("output_bias")   # (S, O)
        except KeyError:
            print(f"[SKIP] Model or posterior not found: -> {model}")
            continue

        S = W1_samples.shape[0]
        rmses = np.zeros(S)
        y_hats = np.zeros((S, y_eval.shape[0]))

        for i in range(S):
            W1 = W1_samples[i]
            W2 = W2_samples[i]

            if apply_pruning:
                masks = prune_fn([W1, W2], sparsity)
                W1 = W1 * masks[0]   # only the input->hidden layer is pruned

            y_hat = forward_pass(X_test, W1, b1_samples[i][0], W2, b2_samples[i]).squeeze()
            y_hats[i] = y_hat
            rmses[i] = np.sqrt(np.mean((y_hat - y_eval) ** 2))

        posterior_mean = y_hats.mean(axis=0)
        posterior_mean_rmse = np.sqrt(np.mean((posterior_mean - y_eval) ** 2))

        posterior_means.append({
            'model': model,
            'sparsity': sparsity,
            'posterior_mean_rmse': posterior_mean_rmse
        })
        results.extend(
            {'model': model, 'sparsity': sparsity, 'rmse': r} for r in rmses
        )

    df_rmse = pd.DataFrame(results)
    df_posterior_rmse = pd.DataFrame(posterior_means)

    return df_rmse, df_posterior_rmse


In [53]:
from utils.sparsity import forward_pass_relu, forward_pass_tanh, local_prune_weights

sparsity_levels = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95]

# Per-draw ("prune per sample") local pruning on the full-data fits (frac=1.0).
# The previously referenced `tanh_fit_nodewise` is not defined anywhere in this
# notebook and raised NameError on a fresh kernel.
df_rmse_tanh_nodewise, df_posterior_rmse_tanh_nodewise = {}, {}

for sparsity in sparsity_levels:
    df_rmse_tanh_nodewise[sparsity], df_posterior_rmse_tanh_nodewise[sparsity] = compute_sparse_rmse_results_abalone(
        models=model_names_tanh_nodewise,
        all_fits=tanh_fit_nodewise_full,
        forward_pass=forward_pass_tanh,
        frac=1.0,
        sparsity=sparsity,
        prune_fn=local_prune_weights
    )


In [54]:
# Long table of posterior-mean RMSE across all sparsity levels (per-draw pruning).
df_post_tanh_full_nodewise_old = pd.concat(
    [frame.assign(sparsity=level)
     for level, frame in df_posterior_rmse_tanh_nodewise.items()],
    ignore_index=True,
)

## Test new pruning scheme: global masks from posterior importance scores

In [6]:
import numpy as np

def build_global_mask_from_posterior(
    W_samples,
    sparsity,
    method="Eabs",          # "Eabs" or "Eabs_stability"
    stability_quantile=0.1, # used if method="Eabs_stability"
    prune_smallest=True
):
    """Build one posterior-level pruning mask for a weight matrix.

    Each entry gets an importance score computed from all posterior draws, and
    exactly ``floor(sparsity * size)`` entries are set to 0 in the mask.

    Parameters
    ----------
    W_samples : array, shape (S, ...)
        Posterior draws of the weight matrix, draw axis first.
    sparsity : float
        Fraction q of entries to prune, ``0 <= q < 1``; ``1 - q`` are kept.
    method : {"Eabs", "Eabs_stability"}
        "Eabs": score = E|w| over draws.
        "Eabs_stability": score = E|w| * P(|w| > t), where t is the
        ``stability_quantile`` quantile of |w| pooled over all draws and entries
        (favours weights that are both large on average and consistently non-tiny).
    stability_quantile : float
        Quantile defining t; only used by "Eabs_stability".
    prune_smallest : bool
        True prunes the lowest scores (usual); False prunes the highest.

    Returns
    -------
    np.ndarray
        Float mask of {0, 1} with shape ``W_samples.shape[1:]``.

    Raises
    ------
    ValueError
        If ``sparsity`` is outside [0, 1) or ``method`` is unknown.

    Ties are broken deterministically: among equally scored entries, those with
    the lowest flat (C-order) index are kept.
    """
    if not 0.0 <= sparsity < 1.0:
        raise ValueError(f"sparsity must be in [0, 1), got {sparsity}")
    W_abs = np.abs(np.asarray(W_samples))  # (S, ...)
    S = W_abs.shape[0]

    # Importance score a = E|w|
    a = W_abs.mean(axis=0)

    if method == "Eabs":
        score = a
    elif method == "Eabs_stability":
        t = np.quantile(W_abs.reshape(S, -1), stability_quantile)
        pi = (W_abs > t).mean(axis=0)
        score = a * pi
    else:
        raise ValueError("method must be 'Eabs' or 'Eabs_stability'")

    num_params = score.size
    k_prune = int(np.floor(sparsity * num_params))
    mask = np.ones(num_params, dtype=float)
    if k_prune == 0:
        return mask.reshape(score.shape)

    flat = score.reshape(-1)
    # np.lexsort sorts by its LAST key first: the primary key is the score
    # (ascending to prune the smallest, negated to prune the largest); the
    # secondary key -index puts higher indices first among ties, so exactly
    # k_prune entries are pruned and lower-index tied entries survive.
    primary = flat if prune_smallest else -flat
    order = np.lexsort((-np.arange(num_params), primary))
    mask[order[:k_prune]] = 0.0
    return mask.reshape(score.shape)


def precompute_global_masks(
    all_fits,
    model,
    sparsity_levels,
    prune_W2=False,
    method="Eabs_stability"
):
    """Posterior-level pruning masks for one model at every requested sparsity.

    Returns
    -------
    dict
        ``{sparsity: (mask_W1, mask_W2)}``; ``mask_W2`` is None unless ``prune_W2``.
    """
    posterior = all_fits[model]["posterior"]
    w1_draws = posterior.stan_variable("W_1")  # presumably (S, P, H)
    w2_draws = posterior.stan_variable("W_L")  # presumably (S, H, O) or (S, H)

    def _masks_at(level):
        first = build_global_mask_from_posterior(w1_draws, level, method=method)
        second = (build_global_mask_from_posterior(w2_draws, level, method=method)
                  if prune_W2 else None)
        return first, second

    return {level: _masks_at(level) for level in sparsity_levels}


In [7]:
from utils.generate_data import load_abalone_regression_data
def _logsumexp(a, axis=None):
    amax = np.max(a, axis=axis, keepdims=True)
    out = amax + np.log(np.sum(np.exp(a - amax), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)

def gaussian_nll_pointwise(y, mu, sigma):
    """Elementwise negative log-density of ``y`` under Normal(mu, sigma**2); broadcasts."""
    variance = sigma**2
    squared_residual = (y - mu)**2
    return 0.5 * np.log(2 * np.pi * variance) + 0.5 * squared_residual / variance

def compute_sparse_metrics_results_globalmask_large_eval(
    models, all_fits, forward_pass,
    sparsity=0.0,
    masks_cache=None,
    prune_W2=False,
    compute_nll=True,
    noise_var_name="sigma",
    frac = 1.0
):
    """Posterior-mean RMSE and Gaussian NLL after applying precomputed global masks.

    A single posterior-level mask is applied to every draw of W_1 (and of W_L
    when ``prune_W2`` is True) before the forward pass. Evaluation uses the test
    split from ``load_abalone_regression_data(standardized=False, frac=frac)``,
    so metrics are on the unstandardized target scale.

    Parameters
    ----------
    models : iterable of str
        Model names to look up in ``all_fits``; missing ones are reported and skipped.
    all_fits : dict
        ``{model: {'posterior': fit}}`` exposing W_1, W_L, hidden_bias, output_bias.
    forward_pass : callable
        ``forward_pass(X, W1, b1, W2, b2)`` returning predictions for X.
    sparsity : float
        Key into ``masks_cache[model]``; no mask is applied when it is 0.0 or
        ``masks_cache`` is None.
    masks_cache : dict or None
        ``{model: {sparsity: (mask_W1, mask_W2_or_None)}}``.
    prune_W2 : bool
        Also apply the W_L mask (when one is present).
    compute_nll : bool
        Also compute expected and predictive (mixture) Gaussian NLL.
    noise_var_name : str
        Noise-scale variable; if it cannot be read, sigma = 1 is used and a
        warning is printed.
    frac : float
        Data fraction passed to the loader.

    Returns
    -------
    pandas.DataFrame
        One row per model: N, model, sparsity, n_eval, posterior_mean_rmse and,
        if ``compute_nll``, expected_nll and predictive_nll.
    """
    posterior_means = []
    X_train, X_test, y_train, y_test = load_abalone_regression_data(standardized=False, frac=frac)
    # Positional 1-D targets: an (n, 1) y_test would otherwise broadcast against
    # (n,) predictions into an (n, n) error matrix.
    y_eval = np.asarray(y_test, dtype=float).reshape(-1)

    for model in models:
        try:
            fit = all_fits[model]['posterior']
            W1_samples = fit.stan_variable("W_1")           # (S, P, H)
            W2_samples = fit.stan_variable("W_L")           # (S, H, O)
            b1_samples = fit.stan_variable("hidden_bias")   # (S, O, H)
            b2_samples = fit.stan_variable("output_bias")   # (S, O)

            noise_samples = None
            if compute_nll:
                try:
                    noise_samples = fit.stan_variable(noise_var_name).squeeze()
                except Exception:
                    noise_samples = None
        except KeyError:
            print(f"[SKIP] Model or posterior not found: -> {model}")
            continue

        S = W1_samples.shape[0]
        y_hats = np.zeros((S, y_eval.shape[0]))

        mask_W1 = mask_W2 = None
        if masks_cache is not None and sparsity > 0.0:
            mask_W1, mask_W2 = masks_cache[model][sparsity]

        for i in range(S):
            W1 = W1_samples[i]
            W2 = W2_samples[i]

            if mask_W1 is not None:
                W1 = W1 * mask_W1
            if prune_W2 and (mask_W2 is not None):
                W2 = W2 * mask_W2

            y_hats[i] = forward_pass(X_test, W1, b1_samples[i][0], W2, b2_samples[i]).squeeze()

        # RMSE of the posterior-mean prediction (unstandardized scale)
        posterior_mean = y_hats.mean(axis=0)
        posterior_mean_rmse = np.sqrt(np.mean((posterior_mean - y_eval) ** 2))

        out_pm = {
            'N': X_train.shape[0],
            'model': model,
            'sparsity': sparsity,
            'n_eval': y_eval.shape[0],
            'posterior_mean_rmse': posterior_mean_rmse,
        }

        if compute_nll:
            if noise_samples is None:
                print(f"[WARN] '{noise_var_name}' not available for {model}; NLL uses sigma = 1")
                sig_s = np.ones(S)
            else:
                sig_s = np.asarray(noise_samples).reshape(-1)[:S]

            # (S, n_eval) pointwise NLL; one sigma per draw broadcast over test points.
            nll = gaussian_nll_pointwise(y_eval[None, :], y_hats, sig_s[:, None])

            # Expected NLL: average of per-draw NLL (equal row lengths -> plain mean).
            out_pm["expected_nll"] = nll.mean()

            # Predictive NLL: -mean_n log( (1/S) sum_s p(y_n | draw s) ).
            lppd = (_logsumexp(-nll, axis=0) - np.log(S)).mean()
            out_pm["predictive_nll"] = -lppd

        posterior_means.append(out_pm)

    return pd.DataFrame(posterior_means)


In [8]:
from utils.sparsity import forward_pass_relu, forward_pass_tanh, local_prune_weights

def build_masks_cache_for_all(
    all_fits,
    models,
    sparsity_levels,
    prune_W2=False,
    method="Eabs_stability"
):
    """Global pruning masks for every model whose fit can be found.

    Returns ``{model: {sparsity: (mask_W1, mask_W2_or_None)}}``. A model whose
    mask computation raises KeyError is reported and left out of the cache.
    """
    cache = {}
    for name in models:
        try:
            per_level = precompute_global_masks(
                all_fits=all_fits,
                model=name,
                sparsity_levels=sparsity_levels,
                prune_W2=prune_W2,
                method=method,
            )
        except KeyError:
            print(f"[SKIP MASKS] Missing fit for -> {name}")
            continue
        cache[name] = per_level
    return cache

sparsity_levels = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95]

# Masks come from the full-data fits, which are also the fits evaluated with
# frac=1.0 below. The previously referenced `tanh_fit_nodewise` is not defined
# anywhere in this notebook and raised NameError on a fresh kernel.
masks_tanh_nodewise = build_masks_cache_for_all(tanh_fit_nodewise_full, model_names_tanh_nodewise, sparsity_levels, prune_W2=False)


In [9]:
# Posterior-level ("global mask") pruning evaluated on the full-data fits; these
# must be the same fits the masks in `masks_tanh_nodewise` were built from.
# (`tanh_fit_nodewise` is not defined anywhere in this notebook.)
df_post_tanh_nodewise = {}

for q in sparsity_levels:
    df_post_tanh_nodewise[q] = compute_sparse_metrics_results_globalmask_large_eval(
        models=model_names_tanh_nodewise,
        all_fits=tanh_fit_nodewise_full,
        forward_pass=forward_pass_tanh,
        sparsity=q,
        masks_cache=masks_tanh_nodewise,
        prune_W2=False,
        compute_nll=True,
        noise_var_name="sigma",
        frac=1.0
    )

In [10]:
# Long table of metrics across all sparsity levels (global-mask pruning).
df_post_tanh_full_nodewise = pd.concat(
    [frame.assign(sparsity=level)
     for level, frame in df_post_tanh_nodewise.items()],
    ignore_index=True,
)

In [None]:
df_post_tanh_full_nodewise.loc[df_post_tanh_full_nodewise["sparsity"] == 0.0].round(3)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


# -----------------------------
# USER SETTINGS
# -----------------------------
colors = {
    "Gaussian tanh": "C0",
    "Regularized Horseshoe tanh": "C1",
    "Dirichlet Horseshoe tanh": "C2",
    "Dirichlet Horseshoe tanh nodewise": "C2",
    "Dirichlet Student T tanh": "C3",
    "Dirichlet Student T tanh nodewise": "C3",
    "Beta Horseshoe tanh": "C4",
    "Beta Horseshoe tanh nodewise": "C4",
    "Beta Student T tanh": "C5",
    "Beta Student T tanh nodewise": "C5",
}

sparsity_levels = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95]
rmse_col = "posterior_mean_rmse"

Ns = [100, 200, 500]  # NOTE(review): not used in this cell

# Provide your dataframes here
dfs = [(df_post_tanh_full_nodewise_old, "Prune per sample"), (df_post_tanh_full_nodewise, "Posterior prune")]

# models you want to skip (keep your list exactly as intended)
skip_models = {
    #"Gaussian tanh",
    #"Dirichlet Horseshoe tanh nodewise",
    #"Dirichlet Horseshoe tanh",
    #"Dirichlet Student T tanh nodewise",
    #"Dirichlet Student T tanh",
    # "Beta Horseshoe tanh",
    # "Beta Horseshoe tanh nodewise",
    # "Beta Student T tanh",
    # "Beta Student T tanh nodewise",
}

fig_path = "figures_for_use_in_paper/abalone_pruning_schemes.pdf"

fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharex=True, sharey=True)

for ax, (df, title) in zip(axes, dfs):
    # filter + enforce sparsity ordering
    df_plot = df[df["sparsity"].isin(sparsity_levels)].copy()
    df_plot["sparsity"] = pd.Categorical(
        df_plot["sparsity"], categories=sparsity_levels, ordered=True
    )

    for model, g in df_plot.groupby("model", sort=False):
        if model in skip_models:
            continue

        g = g.sort_values("sparsity")
        ax.plot(
            g["sparsity"].astype(float),          # safe for categorical x
            g[rmse_col],
            marker="o",
            color=colors.get(model, "C7"),
            label=model,
        )

    ax.set_title(title)
    ax.set_xlabel("sparsity")
    # The last level (0.95) is plotted but gets no tick label (presumably to avoid crowding next to 0.9).
    ax.set_xticks(sparsity_levels[:-1])
    ax.grid(True, linestyle="--", linewidth=0.5)
    ax.legend(loc="upper left", frameon=False)

axes[0].set_ylabel("RMSE")
# Lay out BEFORE saving so the saved PDF matches the displayed figure.
fig.tight_layout()
os.makedirs(os.path.dirname(fig_path), exist_ok=True)
fig.savefig(fig_path, bbox_inches="tight")
plt.show()


In [49]:
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
from utils.generate_data import load_abalone_regression_data
# Full-data training inputs (unstandardized); used for feature names and hidden activations.
(X_train, _, _, _) = load_abalone_regression_data(frac=1.0, standardized=False)

# Architecture of the single-hidden-layer networks being plotted.
P = 8          # input features
H = 16         # hidden units
L = 1          # hidden layers
out_nodes = 1  # output units

# Stan variable names and per-draw shapes of the two weight matrices.
layer_structure = dict(
    input_to_hidden=dict(name='W_1', shape=(P, H)),
    hidden_to_output=dict(name='W_L', shape=(H, out_nodes)),
)


def build_single_draw_weights(fits, layer_structure, draw_idx):
    """Extract one posterior draw of both weight matrices for every model.

    Returns ``{model: {'W_1': W1, 'W_L': WL}}``; ``WL`` is reshaped to
    ``layer_structure['hidden_to_output']['shape']`` (e.g. (H, O)).
    """
    draws = {}
    for model_name, fit_dict in fits.items():
        posterior = fit_dict["posterior"]
        hidden_spec = layer_structure['input_to_hidden']
        output_spec = layer_structure['hidden_to_output']
        w_in = posterior.stan_variable(hidden_spec['name'])[draw_idx]
        w_out = posterior.stan_variable(output_spec['name'])[draw_idx].reshape(output_spec['shape'])
        draws[model_name] = {"W_1": w_in, "W_L": w_out}
    return draws

def scale_W1_for_plot(model_means, mode='global'):
    """
    Scale every model's W_1 into [-1, 1] so edge widths can be compared fairly.

    mode:
      - 'global'   : one shared scale across all models (most comparable)
      - 'per_model': a separate scale per model (independent comparison)
      - 'per_node' : each column (hidden node) of W_1 scaled separately to [-1, 1]

    Divisors are floored at 1e-12, so all-zero matrices/columns stay zero.
    The inputs are not modified; other entries of each model dict (e.g. 'W_L')
    are carried over unscaled. Results are always floating point.

    Returns: (scaled_model_means, scale_info) — same structure as the input, and
    a dict with the 'mode' (plus the shared 'scale' for mode='global').
    """
    if mode == 'global':
        gmax = max(np.abs(m['W_1']).max() for m in model_means.values())
        gmax = max(gmax, 1e-12)
        divisors = {name: gmax for name in model_means}
        info = {'mode': 'global', 'scale': gmax}
    elif mode == 'per_model':
        divisors = {name: max(np.abs(m['W_1']).max(), 1e-12)
                    for name, m in model_means.items()}
        info = {'mode': 'per_model'}
    elif mode == 'per_node':
        # Column-wise maxima, shape (H,), broadcast over the rows of W_1. True
        # division also avoids the truncation that in-place column writes
        # caused for integer-typed W_1.
        divisors = {name: np.maximum(np.abs(m['W_1']).max(axis=0), 1e-12)
                    for name, m in model_means.items()}
        info = {'mode': 'per_node'}
    else:
        raise ValueError("mode must be 'global', 'per_model', or 'per_node'")

    scaled = {}
    for name, m in model_means.items():
        out = dict(m)
        out['W_1'] = m['W_1'] / divisors[name]
        scaled[name] = out
    return scaled, info
# Input column names of the abalone design matrix (optional input-node labels).
feature_names = list(X_train.columns)

# Short model labels used as network-plot panel titles.
abbr = dict([
    ("Gaussian tanh", "Gauss"),
    ("Regularized Horseshoe tanh", "RHS"),
    ("Dirichlet Horseshoe tanh nodewise", "DHS"),
    ("Dirichlet Student T tanh nodewise", "DST"),
    ("Beta Horseshoe tanh", "BHS"),
    ("Beta Student T tanh", "BST"),
])

def plot_models_with_activations(model_means, layer_sizes,
                                 activations=None, activation_color_max=None,
                                 ncols=3, figsize_per_plot=(5,4), signed_colors=False, feature_names=None):
    """
    Draw each model's network as a layered graph, one subplot per model.

    model_means: dict {model_name: {'W_1': (P,H), 'W_L': (H,O), optional 'W_internal': [...]}}
    layer_sizes: e.g. [P, H, O], or [P, H, H, O] when internal layers are present
    activations: dict {model_name: (H,)} - values used to colour ONLY the first hidden
                 layer (divided by activation_color_max, clipped to [0, 1], winter colormap)
    activation_color_max: colour normaliser (1.0 if None)
    ncols, figsize_per_plot: subplot grid; total size scales with the grid
    signed_colors: if True, weights >= 0 are red and negative weights blue; else all red
    feature_names: optional labels drawn to the left of the input nodes

    Edge widths are |w|. Panel titles use the module-level `abbr` mapping,
    falling back to the full model name.

    Returns the matplotlib Figure.
    """
    names = list(model_means.keys())
    n_models = len(names)
    nrows = int(np.ceil(n_models / ncols))
    figsize = (figsize_per_plot[0] * ncols, figsize_per_plot[1] * nrows)

    # squeeze=False always returns a 2-D array of Axes, even for a 1x1 grid.
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    # Turn off unused panels
    for ax in axes[n_models:]:
        ax.axis('off')

    for ax, name in zip(axes, names):
        weights = model_means[name]
        G = nx.DiGraph()
        pos, nodes_per_layer, node_colors = {}, [], []

        # Nodes: x = layer index, y = positions centred around 0
        for li, size in enumerate(layer_sizes):
            ids = []
            ycoords = np.linspace(size - 1, 0, size) - (size - 1) / 2
            color_by_activation = activations is not None and li == 1  # first hidden layer only
            if color_by_activation:
                act = np.asarray(activations.get(name, np.zeros(size))).ravel()
                scale = activation_color_max if activation_color_max is not None else 1.0
            for i in range(size):
                nid = f"L{li}_{i}"
                G.add_node(nid)
                pos[nid] = (li, ycoords[i])
                ids.append(nid)
                if li == 0 and feature_names is not None:
                    ax.text(pos[nid][0]-0.12, pos[nid][1], feature_names[i],
                            ha='right', va='center', fontsize=8)
                if color_by_activation:
                    val = float(np.clip(act[i] / max(scale, 1e-12), 0.0, 1.0))
                    node_colors.append(plt.cm.winter(val))
                else:
                    node_colors.append('lightblue')

            nodes_per_layer.append(ids)

        def add_edges(W, inn, ut):
            # The colour is stored on the edge itself so it stays attached to the
            # right edge: G.edges() iterates source-major, not in insertion order.
            for j, out_n in enumerate(ut):
                for i, in_n in enumerate(inn):
                    w = float(W[i, j])
                    G.add_edge(in_n, out_n, weight=abs(w),
                               color='red' if w >= 0 else 'blue')

        # input -> first hidden
        add_edges(weights['W_1'], nodes_per_layer[0], nodes_per_layer[1])

        # optional internal layers
        if 'W_internal' in weights:
            for l, Win in enumerate(weights['W_internal']):
                add_edges(Win, nodes_per_layer[l+1], nodes_per_layer[l+2])

        # last hidden -> output
        add_edges(weights['W_L'], nodes_per_layer[-2], nodes_per_layer[-1])

        # One edge list drives edges, colours and widths, so they stay aligned.
        edges = list(G.edges(data=True))
        nx.draw(G, pos, ax=ax,
                node_color=node_colors,
                edgelist=[(u, v) for u, v, _attrs in edges],
                edge_color=([attrs['color'] for _u, _v, attrs in edges] if signed_colors else 'red'),
                width=[attrs['weight'] for _u, _v, attrs in edges],
                with_labels=False, node_size=400, arrows=False)

        ax.set_title(abbr.get(name, name), fontsize=10)
        ax.axis('off')

    fig.tight_layout()
    return fig

def compute_hidden_activation(fit_dict, x_train, draw_idx):
    """Mean first-hidden-layer tanh activation over the rows of ``x_train`` for one draw.

    Parameters
    ----------
    fit_dict : dict
        Holds the fit under ``'posterior'``; ``stan_variable('W_1')`` is indexed
        as ``[draw, P, H]``.
    x_train : array-like, shape (N, P)
        Inputs; converted to a float ndarray, so a DataFrame also yields a
        plain ndarray result.
    draw_idx : int
        Posterior draw to use.

    Returns
    -------
    np.ndarray, shape (H,)
        Mean over the N rows of tanh(x @ W1 + b1); each entry lies in [-1, 1].

    If 'hidden_bias' cannot be read, a zero bias is used.
    """
    fit = fit_dict['posterior']
    W1 = fit.stan_variable('W_1')[draw_idx, :, :]          # (P, H)
    try:
        # Presumably stored per draw as (O, H) with O = 1 (other cells index it
        # as [i][0]); flattened to (H,). A plain (H,) layout works too.
        b1 = np.asarray(fit.stan_variable('hidden_bias')[draw_idx]).reshape(-1)
    except Exception:
        b1 = np.zeros(W1.shape[1])
    x = np.asarray(x_train, dtype=float)
    a_full = np.tanh(x @ W1 + b1)                          # (N, H), tanh in [-1, 1]
    return a_full.mean(axis=0)


In [None]:
# Use one common posterior draw for every model so the panels are comparable.
obs_idx = 3  # NOTE(review): not used below
draw_idx = 69
# Despite the name, these are single POSTERIOR draws of W_1 / W_L for each model.
prior_draws = build_single_draw_weights(tanh_fit_nodewise_full, layer_structure, draw_idx)

# The paper figure uses plain node colours; set True to colour the first hidden
# layer by the mean |tanh activation| on X_train instead.
SHOW_ACTIVATIONS = False

# 1) Mean |activation| per hidden node, for ALL models
activations = {}
for name, fd in tanh_fit_nodewise_full.items():
    a = compute_hidden_activation(fd, X_train, draw_idx)
    activations[name] = np.abs(a)

# 2) Scale W_1 per model to [-1, 1] for the edge widths
scaled, _ = scale_W1_for_plot(prior_draws, mode='per_model')

# 3) Plot. tanh lies in [-1, 1] and |a| is used, so activation_color_max=1.0.
fig = plot_models_with_activations(
    scaled,
    layer_sizes=[P, H, out_nodes],
    activations=activations if SHOW_ACTIVATIONS else None,
    activation_color_max=1.0,
    ncols=2,
    feature_names=None,
)
fig_path = "figures_for_use_in_paper/abalone_network_tanh.png"
os.makedirs(os.path.dirname(fig_path), exist_ok=True)
fig.savefig(fig_path, bbox_inches="tight")
plt.show()