In [None]:
import sys, os; sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__) if '__file__' in globals() else os.getcwd(), '..')))
#import os; os.chdir(os.path.dirname(os.getcwd()))
from utils.model_loader import get_model_fits
import numpy as np
import pandas as pd
import re
from sklearn.metrics import mean_squared_error
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
import numpy as np

# params
# params
# n: number of samples, pin: observed input dimension, d: latent dimension
n, pin, d = 200, 100, 8
# H_star: number of target-generating hidden units (columns of H);
# r: how many of them are "active" (rescaled by active_scale and given nonzero output weight)
H_star, r = 20, 5
# row scales applied to B: first r rows (active) vs. the remaining rows (inactive)
active_scale, inactive_scale = 1.5, 0.1
# standard deviations of the additive Gaussian noise on X and on y
x_noise, y_noise = 0.05, 0.3
# single generator for every draw below: reordering any rng.* call changes all downstream arrays
rng = np.random.default_rng(42)

# orthonormal A (pin x d), A^T A = I_d
# (reduced QR of a pin x d Gaussian matrix, pin >= d, gives orthonormal columns)
A_rand = rng.standard_normal((pin, d))
A, _ = np.linalg.qr(A_rand)

# latent Z and observed inputs X
# X is a noisy linear embedding of the d-dim latent Z into pin dimensions
Z = rng.standard_normal((n, d))
X = Z @ A.T + x_noise * rng.standard_normal((n, pin))

# post-activation feature map H = tanh(Z @ B^T) with anisotropy
B = rng.standard_normal((H_star, d))

# make first r rows orthonormal, scale active/inactive
# QR of the (d x r) block (d >= r here) yields Qb with orthonormal columns,
# so Qb[:, :r].T has orthonormal rows
Qb, _ = np.linalg.qr(B[:r, :].T)
B[:r, :] = Qb[:, :r].T
B[:r, :] *= active_scale
B[r:, :] *= inactive_scale

H = np.tanh(Z @ B.T)  # (n, H_star); built from the latent Z, not from the observed X

# output weights supported on first r coords
w_star = np.zeros(H_star)
w_star[:r] = rng.standard_normal(r)

# targets
y = H @ w_star + y_noise * rng.standard_normal(n)

# simple train/val split indices
# 80/20 random split; the held-out 20% (va_idx) is stored as X_test / y_test below
perm = rng.permutation(n)
n_tr = int(0.8 * n)
tr_idx, va_idx = perm[:n_tr], perm[n_tr:]

X_train, y_train = X[tr_idx], y[tr_idx]
X_test,   y_test   = X[va_idx], y[va_idx]

# quick shapes check
(X_train.shape, X_test.shape, y_train.shape, y_test.shape, H.shape)


In [None]:
# Posterior fits for the four tanh-activation priors compared in this notebook
# (prior-only fits are not loaded).
results_dir_tanh = "results/ridgeless/alignment/priors"
model_names_tanh = [
    "Gaussian tanh",
    "Regularized Horseshoe tanh",
    "Dirichlet Horseshoe tanh",
    "Dirichlet Student T tanh",
]

tanh_fit = get_model_fits(config="", results_dir=results_dir_tanh,
                          models=model_names_tanh, include_prior=False)

In [None]:
import numpy as np
import pandas as pd
from sklearn.metrics import mean_squared_error
from properscoring import crps_ensemble

results = []
rmse_per_model = {}   # model name -> per-draw test RMSE array (kept for later plotting)
S = 4000              # cap on posterior-predictive draws used per model

for model_name, model_entry in tanh_fit.items():
    posterior = model_entry["posterior"]

    # Posterior-predictive draws on the test set. 'output_test' may carry a trailing
    # singleton output axis, (S_all, n_test, 1); drop it only when present so a plain
    # (S_all, n_test) array is accepted too (an unconditional squeeze(-1) raises there).
    y_draws = np.asarray(posterior.stan_variable("output_test"))
    if y_draws.ndim == 3 and y_draws.shape[-1] == 1:
        y_draws = y_draws[..., 0]
    if y_draws.ndim != 2 or y_draws.shape[1] != y_test.shape[0]:
        raise ValueError(
            f"{model_name}: expected 'output_test' draws shaped (S, {y_test.shape[0]}), "
            f"got {y_draws.shape}"
        )
    y_draws = y_draws[:S]   # keep at most the first S draws

    # --- Per-draw RMSE (vectorized): MSE over test points, one value per draw ---
    mse_draws = ((y_draws - y_test[None, :]) ** 2).mean(axis=1)
    rmse_draws = np.sqrt(mse_draws)  # shape (S,)
    rmse_per_model[model_name] = rmse_draws

    # --- Ensemble CRPS (proper scoring rule; lower is better) ---
    # properscoring expects forecasts shaped (n_obs, n_members), hence the transpose.
    crps_vec = crps_ensemble(y_test, y_draws.T)      # shape (n_test,)
    crps_mean = float(crps_vec.mean())

    results.append({
        "Model": model_name,
        "RMSE_per_draw_mean": float(rmse_draws.mean()),
        "RMSE_per_draw_median": float(np.median(rmse_draws)),
        "RMSE_per_draw_p25": float(np.percentile(rmse_draws, 25)),
        "RMSE_per_draw_p75": float(np.percentile(rmse_draws, 75)),
        "CRPS_ensemble_mean": crps_mean,
    })

results_df = pd.DataFrame(results)
results_df


In [5]:
import numpy as np

def tanh_act(A):
    """Element-wise hyperbolic-tangent activation of array ``A``."""
    activated = np.tanh(A)
    return activated

def post_acts(X, W1, b1):
    """Hidden-layer post-activations ``tanh_act(X @ W1.T + b1)``.

    Intended shapes: X (n, p), W1 (H, p), b1 (H,) -> result (n, H).
    """
    pre_activation = X @ W1.T + b1[np.newaxis, :]
    return tanh_act(pre_activation)

def cov_spectrum(Z):
    """Column covariance of ``Z`` (1/n normalization) and its eigen-decomposition.

    Returns (Sig, eigenvalues sorted descending, eigenvectors as matching columns).
    """
    centered = Z - Z.mean(0, keepdims=True)
    n_rows = centered.shape[0]
    Sig = (centered.T @ centered) / n_rows
    evals, evecs = np.linalg.eigh(Sig)
    descending = np.argsort(evals)[::-1]
    return Sig, evals[descending], evecs[:, descending]

def ridge_fit(Z, y, lam=1e-3):
    """Ridge regression with an intercept column; returns ``theta = [b, w]``.

    Solves (Z1^T Z1 + lam * n * I) theta = Z1^T y with Z1 = [1, Z]. Note the
    penalty is applied to every coefficient, including the intercept.
    """
    design = np.c_[np.ones((Z.shape[0], 1)), Z]
    n_rows, n_cols = design.shape
    gram = design.T @ design + lam * n_rows * np.eye(n_cols)
    rhs = design.T @ y
    return np.linalg.solve(gram, rhs)

def ridge_pred(Z, theta):
    """Predictions ``[1, Z] @ theta`` for a ``theta = [b, w]`` returned by ``ridge_fit``."""
    ones_col = np.ones((Z.shape[0], 1))
    return np.c_[ones_col, Z] @ theta

def min_norm_fit(Z, y):
    """Minimum-norm least-squares coefficients ``[b, w]`` via the pseudoinverse.

    A bias column of ones is prepended to ``Z``; in the underdetermined case this
    is the interpolating solution with the smallest Euclidean norm.
    """
    n_rows = Z.shape[0]
    augmented = np.c_[np.ones((n_rows, 1)), Z]   # (n, 1 + H)
    pseudo_inverse = np.linalg.pinv(augmented)
    return pseudo_inverse @ y

def min_norm_pred(Z, theta):
    """Apply coefficients ``[b, w]`` from ``min_norm_fit`` to features ``Z``."""
    with_bias = np.c_[np.ones((Z.shape[0], 1)), Z]
    predictions = with_bias @ theta
    return predictions


def r2_score(y, yhat):
    """Coefficient of determination ``1 - SS_res / SS_tot``.

    ``SS_tot`` is padded by 1e-12 so a constant ``y`` does not divide by zero.
    """
    residual_ss = np.sum((y - yhat) ** 2)
    centered = y - y.mean()
    total_ss = np.sum(centered ** 2) + 1e-12
    return 1.0 - residual_ss / total_ss

def align_geom(WL, U, k):
    """Fraction of ``||WL||^2`` lying in the span of the first ``k`` columns of ``U``.

    ``U`` is expected to hold eigenvectors of the feature covariance sorted by
    decreasing eigenvalue (as returned by ``cov_spectrum``).
    """
    weights = np.ravel(WL)
    leading = U[:, :k]
    captured = np.linalg.norm(leading.T @ weights) ** 2
    total = np.linalg.norm(weights) ** 2 + 1e-12
    return float(captured / total)

def align_signal(WL, lam, U, Sig, k=None):
    """Share of the readout's signal variance ``WL^T Sig WL`` explained by the top-k PCs.

    With ``k=None`` all eigen-directions are used (ratio is ~1 when ``U``/``lam``
    diagonalize ``Sig``).
    """
    weights = np.ravel(WL)
    if k is None:
        coords, spectrum = weights @ U, lam
    else:
        coords, spectrum = weights @ U[:, :k], lam[:k]
    explained = (coords ** 2 * spectrum).sum()
    total = weights @ Sig @ weights + 1e-12
    return float(explained / total)


def rks_padding(X_train, X_test, H_base, target_gamma, rng, act=np.tanh, scale=1.0):
    """
    Random-feature padding so that (H_base + extras) / n_train reaches target_gamma.

    The number of extras is max(0, ceil(target_gamma * n_train) - H_base). Each extra
    feature is act(x @ w + b) with w ~ N(0, scale^2 / p_in) and b ~ N(0, scale^2),
    applied to both train and test inputs with the same (w, b).

    Returns (extra_train, extra_test) of shapes (n_train, m) and (n_test, m); m may be 0.
    """
    n_rows = X_train.shape[0]
    p_in = X_train.shape[1]
    m_extra = max(0, int(np.ceil(target_gamma * n_rows)) - H_base)
    if not m_extra:
        return np.zeros((n_rows, 0)), np.zeros((X_test.shape[0], 0))

    # Draw weights before biases (order matters for reproducibility with a seeded rng).
    weights = rng.standard_normal((m_extra, p_in)) * (scale / np.sqrt(p_in))
    biases = rng.standard_normal(m_extra) * scale
    offset = biases[None, :]

    train_feats = act(X_train @ weights.T + offset)
    test_feats = act(X_test @ weights.T + offset)
    return train_feats, test_feats

def orthogonal_noise_padding(Z_tr_base, Z_te_base, m_extra, rng):
    """
    Create ``m_extra`` Gaussian nuisance features for train and test.

    Train: Gaussian proposals G_tr have their least-squares projection onto the
    column span of ``Z_tr_base`` removed, N_tr = G_tr - Z_tr_base @ C with
    C = pinv(Z_tr_base) @ G_tr, so the raw train columns are orthogonal (in sample
    space) to the columns of ``Z_tr_base``.

    Test: fresh Gaussian proposals G_te with the same train-fitted coefficients
    removed, N_te = G_te - Z_te_base @ C. (This is a heuristic: the extras are pure
    per-sample noise, so there is no exact "same function" to evaluate on test.)

    Both outputs are then standardized with the *train* column mean and std, so
    train and test extras share one scaling and no test statistics are used.
    Orthogonality holds for the raw residuals; centering may break it.

    Parameters
    ----------
    Z_tr_base : (n_tr, H_base) array of base train features.
    Z_te_base : (n_te, H_base) array of base test features.
    m_extra   : number of extra columns to create.
    rng       : numpy Generator; G_tr is drawn before G_te.

    Returns
    -------
    (N_tr, N_te) with shapes (n_tr, m_extra) and (n_te, m_extra).
    """
    Z_pinv = np.linalg.pinv(Z_tr_base)                   # (H_base, n_tr)

    # Train extras: remove the component explained by Z_tr_base. Computing
    # G - Z (Z^+ G) avoids materializing the n_tr x n_tr projector.
    G_tr = rng.standard_normal((Z_tr_base.shape[0], m_extra))
    coef = Z_pinv @ G_tr                                 # (H_base, m_extra)
    N_tr = G_tr - Z_tr_base @ coef

    # Test extras: fresh proposals minus the same train-fitted coefficients.
    G_te = rng.standard_normal((Z_te_base.shape[0], m_extra))
    N_te = G_te - Z_te_base @ coef

    # Standardize both with train statistics (epsilon guards zero-variance columns).
    mu = N_tr.mean(axis=0, keepdims=True)
    sd = N_tr.std(axis=0, keepdims=True) + 1e-12
    return (N_tr - mu) / sd, (N_te - mu) / sd


def _std_shapes(W1_s, b1_s, WL_s, p_in):
    # W1_s: (p,H) or (H,p) -> make (H,p)
    W1_s = np.asarray(W1_s)
    if W1_s.ndim != 2:
        W1_s = W1_s.squeeze()
    # if first dim equals p_in, it's (p,H) -> transpose
    if W1_s.shape[0] == p_in:
        W1_s = W1_s.T
    # b1_s: (1,H) or (H,) -> make (H,)
    b1_s = np.asarray(b1_s).squeeze().ravel()
    # WL_s: (H,1) or (1,H) or (H,) -> make (H,)
    WL_s = np.asarray(WL_s).squeeze().ravel()
    return W1_s, b1_s, WL_s


In [6]:
import numpy as np


def analyze_model(posterior, X_train, y_train, X_val, y_val, fit = "min_norm",
                  ks=(1,2,3,5,10), lam_ridge=1e-3, draws=20, thin=1):
    """
    Probe how a posterior's hidden (post-activation) features relate to the target.

    For each selected posterior draw (every ``thin``-th draw starting at 0, at most
    ``draws`` of them, never beyond the available draws) this:
      * builds hidden features on train/val via ``post_acts`` from W_1 / hidden_bias,
      * fits a readout on all hidden features (min-norm if ``fit == "min_norm"``,
        ridge with ``lam_ridge`` otherwise) and records validation R^2,
      * for each requested k: the top-k variance fraction of the train feature
        covariance, the geometric and signal alignment of W_L with the top-k
        eigenvectors, and validation R^2 of a ridge probe on the top-k PCs.

    Parameters
    ----------
    posterior : object exposing ``stan_variable(name)`` for 'W_1', 'hidden_bias', 'W_L'.
    X_train, y_train, X_val, y_val : train / validation inputs and targets.
    fit : "min_norm" for the pseudoinverse readout; anything else uses ridge.
    ks : requested subspace sizes. A k larger than the hidden width is clamped to
        the width for computation, but results stay keyed by the requested k.
    lam_ridge : ridge strength for the top-k probes (and the full fit when not min-norm).
    draws, thin : draw selection as described above.

    Returns
    -------
    summary : dict of medians over draws.
    results : dict of per-draw lists (per-k dicts keyed by the requested k).
    """
    W1_all = posterior.stan_variable('W_1')          # (S, p, H) per the original notes; _std_shapes normalizes
    b1_all = posterior.stan_variable('hidden_bias')  # (S, 1, H) per the original notes
    WL_all = posterior.stan_variable('W_L')          # (S, H, 1) per the original notes

    # handle possible 3D WL -> (S,H)
    if WL_all.ndim == 3:
        WL_all = WL_all.squeeze(-1)

    S_avail = W1_all.shape[0]
    thin = max(1, int(thin))
    # Every `thin`-th draw, at most `draws` of them, never indexing past S_avail
    # (arange(0, draws*thin, thin) can overrun when thin > 1).
    idxs = np.arange(0, S_avail, thin)[:draws]

    p_in = X_train.shape[1]

    results = { 'varfrac@k': {k: [] for k in ks},
                'align_geom@k': {k: [] for k in ks},
                'align_signal@k': {k: [] for k in ks},
                'R2_full': [], 'R2_topk': {k: [] for k in ks} }

    for s in idxs:
        # ---- standardize shapes per draw ----
        W1_s, b1_s, WL_s = _std_shapes(W1_all[s], b1_all[s], WL_all[s], p_in)

        # ---- features on train/val ----
        Ztr = post_acts(X_train, W1_s, b1_s)   # (n_tr, H)
        Zva = post_acts(X_val,   W1_s, b1_s)   # (n_val, H)

        # ---- covariance & PCs ----
        Sig, lam, U = cov_spectrum(Ztr)
        vf = lam / (lam.sum() + 1e-12)

        # ---- readout on all features (min-norm or ridge) ----
        if fit == "min_norm":
            theta = min_norm_fit(Ztr, y_train)
            yhat  = min_norm_pred(Zva, theta)
        else:
            theta = ridge_fit(Ztr, y_train, lam=lam_ridge)
            yhat  = ridge_pred(Zva, theta)
        results['R2_full'].append(r2_score(y_val, yhat))

        # ---- top-k probes & alignment ----
        width = Ztr.shape[1]
        for k in ks:
            # Clamp for computation only; keep the requested k as the results key
            # (rebinding k itself caused KeyError / appends to the wrong k).
            k_eff = min(k, width)
            results['varfrac@k'][k].append(vf[:k_eff].sum())
            results['align_geom@k'][k].append(align_geom(WL_s, U, k_eff))
            results['align_signal@k'][k].append(align_signal(WL_s, lam, U, Sig, k=k_eff))

            Ztr_k = Ztr @ U[:, :k_eff]
            Zva_k = Zva @ U[:, :k_eff]
            theta_k = ridge_fit(Ztr_k, y_train, lam=lam_ridge)
            yhat_k  = ridge_pred(Zva_k, theta_k)
            results['R2_topk'][k].append(r2_score(y_val, yhat_k))

    # Median over draws; NaN (without a numpy warning) when no draws were selected.
    med = lambda a: float(np.median(np.asarray(a))) if len(a) else float('nan')
    summary = {
        'R2_full_med': med(results['R2_full']),
        'varfrac@k_med': {k: med(v) for k,v in results['varfrac@k'].items()},
        'align_geom@k_med': {k: med(v) for k,v in results['align_geom@k'].items()},
        'align_signal@k_med': {k: med(v) for k,v in results['align_signal@k'].items()},
        'R2_topk_med': {k: med(v) for k,v in results['R2_topk'].items()},
    }
    return summary, results


In [None]:

# Alignment / probe diagnostics for each prior's posterior:
# 20 draws, taking every 5th draw, min-norm readout on all hidden features.
post_gauss = tanh_fit['Gaussian tanh']['posterior']
post_RHS = tanh_fit['Regularized Horseshoe tanh']['posterior']
post_DHS = tanh_fit['Dirichlet Horseshoe tanh']['posterior']
post_DST = tanh_fit['Dirichlet Student T tanh']['posterior']

ks = (1, 2, 3, 5, 10)
probe_opts = dict(ks=ks, lam_ridge=1e-3, draws=20, thin=5)

gauss_sum, _ = analyze_model(post_gauss, X_train, y_train, X_test, y_test, **probe_opts)
RHS_sum, _ = analyze_model(post_RHS, X_train, y_train, X_test, y_test, **probe_opts)
DHS_sum, _ = analyze_model(post_DHS, X_train, y_train, X_test, y_test, **probe_opts)
DST_sum, _ = analyze_model(post_DST, X_train, y_train, X_test, y_test, **probe_opts)

for banner, prior_summary in (
    ('\n--- GAUSSIAN prior ---', gauss_sum),
    ('--- RHS prior ---', RHS_sum),
    ('--- DHS prior ---', DHS_sum),
    ('--- DST prior ---', DST_sum),
):
    print(banner)
    print('R2_full (med):', prior_summary['R2_full_med'])
    print('Top-k varfrac (med):', prior_summary['varfrac@k_med'])
    print('Align geom (med):', prior_summary['align_geom@k_med'])
    print('Align signal (med):', prior_summary['align_signal@k_med'])
    print('R2_topk (med):', prior_summary['R2_topk_med'])


In [6]:
import numpy as np

def analyze_intensive_augmented_model(
    posterior, X_train, y_train, X_val, y_val, 
    draws=20, thin=1,
    gamma_list=(0.9, 1.1, 1.25, 1.5, 2.0, 3.0, 5.0, 8.0),
    padding_mode='orth',          # 'orth' (recommended) or 'rks'
    rng=None,
    # stability & scaling knobs
    rcond=1e-10,                  # rcond cutoff passed to np.linalg.lstsq
    standardize_extras=True,
    tail_scale=0.2,               # multiplier applied to the (standardized) extra columns
    rks_act=np.tanh,
    rks_scale=0.1                 # scale for RKS weights/bias (effective ~ rks_scale/sqrt(p_in))
):
    """
    Ridgeless risk vs. gamma = p_aug / n_train, obtained by padding the *hidden*
    features (post-activations) of each posterior draw with extra columns.

    For each selected posterior draw (every ``thin``-th draw starting at 0, at most
    ``draws`` of them, never beyond the available draws) and each gamma, the draw's
    hidden features are padded to ceil(gamma * n_train) columns (never fewer than the
    base width), and a min-norm least-squares readout with an intercept is fitted on
    train and evaluated on validation.

    Padding modes
    -------------
    'orth' : Gaussian columns whose raw train values are orthogonal (in sample space)
             to the base train features; validation extras are fresh Gaussians with
             the same train-fitted projection removed.
    'rks'  : random features rks_act(x @ w + b) of the *inputs*.
    In both modes the extras are (optionally) standardized with train statistics
    and multiplied by ``tail_scale``.

    Returns
    -------
    summary : dict
        'gamma', 'rmse_val_med', 'rmse_tr_med', 'p_aug_med' (medians over draws),
        'used_padding', 'tail_scale', and 'bias_med', 'variance_med', 'risk_med'.
        Despite the '_med' suffix, the last three are ensemble quantities:
          bias     = mean_i (mean_s yhat_s,i - y_i)^2   (against noisy targets, so it
                     includes irreducible noise),
          variance = mean_{s,i} (yhat_s,i - mean_s yhat_s,i)^2,
          risk     = bias + variance (= mean over draws of the per-draw val MSE).
    results : dict
        'gamma' plus per-gamma lists over draws for 'rmse_val', 'rmse_tr', 'p_aug'.
    """
    if rng is None:
        rng = np.random.default_rng(0)

    # Flat targets so per-draw residuals never broadcast (n,) against (n, 1).
    y_train = np.asarray(y_train).ravel()
    y_val = np.asarray(y_val).ravel()

    # --- pull posterior samples
    W1_all = posterior.stan_variable('W_1')          # (S, p, H) or (S, H, p); normalized by _std_shapes
    b1_all = posterior.stan_variable('hidden_bias')  # (S, 1, H) or (S, H)
    WL_all = posterior.stan_variable('W_L')          # not used in the fit; only passed through _std_shapes

    if WL_all.ndim == 3:
        WL_all = WL_all.squeeze(-1)

    S_avail = W1_all.shape[0]
    thin = max(1, int(thin))
    # Every `thin`-th draw, at most `draws` of them, never indexing past S_avail
    # (arange(0, draws*thin, thin) can overrun when thin > 1).
    idxs = np.arange(0, S_avail, thin)[:draws]

    p_in = X_train.shape[1]
    n_tr = X_train.shape[0]
    gammas = list(gamma_list)

    # --------- inner helpers ---------
    def _standardize_pair(A_tr, A_va):
        """Scale train/val extras with the *train* column mean/std (if enabled)."""
        if not standardize_extras:
            return A_tr, A_va
        mu = A_tr.mean(axis=0, keepdims=True)
        sd = A_tr.std(axis=0, keepdims=True) + 1e-12
        return (A_tr - mu) / sd, (A_va - mu) / sd

    def _ridgeless_solve(Z, y):
        """
        Min-norm ridgeless: solve Z1 theta ≈ y (least-squares) with bias column included.
        lstsq returns the min-norm solution in underdetermined cases.
        """
        Z1 = np.c_[np.ones((Z.shape[0], 1)), Z]
        theta, *_ = np.linalg.lstsq(Z1, y, rcond=rcond)
        return theta

    def _predict(Z, theta):
        """Predictions [1, Z] @ theta."""
        Z1 = np.c_[np.ones((Z.shape[0], 1)), Z]
        return Z1 @ theta

    def _rks_padding_from_X(H_base, g, Xtr, Xva):
        """Random input features topping the width up to ceil(g * n_tr); returns (tr, va, p_aug)."""
        p_target = int(np.ceil(g * n_tr))
        m_extra = max(0, p_target - H_base)
        if m_extra == 0:
            return np.zeros((Xtr.shape[0], 0)), np.zeros((Xva.shape[0], 0)), H_base
        W = rng.standard_normal((m_extra, p_in)) * (rks_scale / np.sqrt(p_in))
        b = rng.standard_normal(m_extra) * rks_scale
        Z_extra_tr = rks_act(Xtr @ W.T + b[None, :])
        Z_extra_va = rks_act(Xva @ W.T + b[None, :])
        Z_extra_tr, Z_extra_va = _standardize_pair(Z_extra_tr, Z_extra_va)
        Z_extra_tr *= tail_scale
        Z_extra_va *= tail_scale
        return Z_extra_tr, Z_extra_va, H_base + m_extra

    def _orth_padding_from_Z(Ztr_base, Zva_base, g):
        """
        Create extras that are orthogonal (on train) to span(Ztr_base) in *sample space*,
        then standardize + tail-shrink. Returns (tr, va, p_aug).
        """
        H_base = Ztr_base.shape[1]
        p_target = int(np.ceil(g * n_tr))
        m_extra = max(0, p_target - H_base)
        if m_extra == 0:
            return np.zeros((Ztr_base.shape[0], 0)), np.zeros((Zva_base.shape[0], 0)), H_base

        # projector onto span(Z) in sample space
        Z = Ztr_base
        Z_pinv = np.linalg.pinv(Z)        # (H_base, n_tr)
        P = Z @ Z_pinv                    # (n_tr, n_tr)
        I = np.eye(n_tr)

        # proposals and removal on train
        G_tr = rng.standard_normal((n_tr, m_extra))
        N_tr = (I - P) @ G_tr

        # validation: fresh proposals minus the same train-fitted coefficients
        G_va = rng.standard_normal((Zva_base.shape[0], m_extra))
        N_va = G_va - Zva_base @ (Z_pinv @ G_tr)

        # standardize + tail shrink
        N_tr, N_va = _standardize_pair(N_tr, N_va)
        N_tr *= tail_scale
        N_va *= tail_scale
        return N_tr, N_va, H_base + m_extra
    # ---------------------------------

    # storage
    results = {
        'gamma': gammas,
        'rmse_val': {g: [] for g in gammas},
        'rmse_tr':  {g: [] for g in gammas},
        'p_aug':    {g: [] for g in gammas},
    }
    pred_val = {g: [] for g in gammas}   # per gamma: one yhat_va array per draw

    # --------- main loop over posterior draws ---------
    for s in idxs:
        # shapes & base hidden features
        W1_s, b1_s, _ = _std_shapes(W1_all[s], b1_all[s], WL_all[s], p_in)
        Ztr = post_acts(X_train, W1_s, b1_s)   # (n_tr, H_base)
        Zva = post_acts(X_val,   W1_s, b1_s)   # (n_val, H_base)
        H_base = Ztr.shape[1]

        for g in gammas:
            # augment hidden features to reach target gamma
            if padding_mode == 'orth':
                Z_extra_tr, Z_extra_va, p_aug = _orth_padding_from_Z(Ztr, Zva, g)
            elif padding_mode == 'rks':
                Z_extra_tr, Z_extra_va, p_aug = _rks_padding_from_X(H_base, g, X_train, X_val)
            else:
                raise ValueError("padding_mode must be 'orth' or 'rks'")

            Z_aug_tr = np.c_[Ztr, Z_extra_tr]
            Z_aug_va = np.c_[Zva, Z_extra_va]

            # ridgeless min-norm on augmented hidden features
            theta = _ridgeless_solve(Z_aug_tr, y_train)
            yhat_tr = _predict(Z_aug_tr, theta)
            yhat_va = _predict(Z_aug_va, theta)
            pred_val[g].append(yhat_va)

            rmse_tr = float(np.sqrt(np.mean((y_train - yhat_tr)**2)))
            rmse_va = float(np.sqrt(np.mean((y_val   - yhat_va)**2)))

            results['rmse_tr'][g].append(rmse_tr)
            results['rmse_val'][g].append(rmse_va)
            results['p_aug'][g].append(p_aug)

    # ---- ensemble bias / variance / risk per gamma ----
    bias = {}
    variance = {}
    risk = {}

    for g in gammas:
        if len(pred_val[g]) == 0:
            bias[g] = variance[g] = risk[g] = float('nan')
            continue

        Y = np.vstack(pred_val[g])    # (n_draws, n_val)
        mean_pred = Y.mean(axis=0)    # (n_val,)

        bias[g] = float(((mean_pred - y_val)**2).mean())
        variance[g] = float(((Y - mean_pred)**2).mean())
        risk[g] = bias[g] + variance[g]

    # aggregate (median over draws)
    med = lambda a: float(np.median(np.asarray(a))) if len(a) else float('nan')

    summary = {
        'gamma': gammas,
        'rmse_val_med': [med(results['rmse_val'][g]) for g in gammas],
        'rmse_tr_med':  [med(results['rmse_tr'][g])  for g in gammas],
        'p_aug_med':    [med(results['p_aug'][g])    for g in gammas],
        'used_padding': padding_mode,
        'tail_scale': tail_scale
    }
    # Ensemble quantities (not medians); key names kept for backward compatibility.
    summary['bias_med'] = [bias[g] for g in gammas]
    summary['variance_med'] = [variance[g] for g in gammas]
    summary['risk_med'] = [risk[g] for g in gammas]

    return summary, results


In [22]:
import numpy as np

def analyze_intensive_augmented_model_2(
    posterior, X_train, y_train, X_val, y_val, 
    draws=20, thin=1,
    gamma_list=(0.9, 1.1, 1.25, 1.5, 2.0, 3.0, 5.0, 8.0),
    padding_mode='orth',          # 'orth' (recommended) or 'rks'
    rng=None,
    # stability & scaling knobs
    rcond=1e-10,                  # rcond cutoff passed to np.linalg.lstsq
    standardize_extras=True,
    tail_scale=0.2,               # multiplier applied to the (standardized) extra columns
    rks_act=np.tanh,
    rks_scale=0.1                 # scale for RKS weights/bias (effective ~ rks_scale/sqrt(p_in))
):
    """
    Ridgeless-only: augment hidden features to a target gamma and solve min-norm LS.

    For each selected posterior draw (every ``thin``-th draw starting at 0, at most
    ``draws`` of them, never beyond the available draws) and each gamma, the draw's
    hidden features are padded to ceil(gamma * n_train) columns (never fewer than the
    base width) with 'orth' or 'rks' extras, and a min-norm least-squares readout
    with an intercept is fitted on train. Pass ``draws=1`` for a quick single-draw
    run (the variance curve is then identically 0).

    RETURNS
    -------
    summary : dict
        {
          'gamma': [...],
          'used_padding': 'orth' | 'rks',
          'tail_scale': float,
          # one RMSE value per gamma = mean over draws
          'rmse_mean_val': [...],      # mean over draws of per-draw validation RMSE
          'rmse_mean_tr':  [...],      # mean over draws of per-draw train RMSE
          # Ensemble (across draws) decomposition curves:
          'bias_curve':     [...],     # mean_i (mean_s yhat_s,i - y_i)^2; includes target noise
          'variance_curve': [...],     # mean_{s,i} (yhat_s,i - mean_s yhat_s,i)^2
          'risk_curve':     [...]      # bias + variance = mean over draws of per-draw val MSE
        }

    results : dict
        {
          'pred_val': {gamma: [yhat_val_draw1, ..., yhat_val_drawS]},  # each (n_val,)
          'per_sample': {
              gamma: {
                  'rmse_tr': [...], 'rmse_val': [...], 'mse_val': [...], 'p_aug': [...]
              }, ...
          }
        }
    """
    if rng is None:
        rng = np.random.default_rng(0)

    # Flat targets so per-draw residuals never broadcast (n,) against (n, 1).
    y_train = np.asarray(y_train).ravel()
    y_val = np.asarray(y_val).ravel()

    # --- pull posterior samples
    W1_all = posterior.stan_variable('W_1')          # (S, p, H) or (S, H, p); normalized by _std_shapes
    b1_all = posterior.stan_variable('hidden_bias')  # (S, 1, H) or (S, H)
    WL_all = posterior.stan_variable('W_L')          # not used in the fit; only passed through _std_shapes

    if WL_all.ndim == 3:
        WL_all = WL_all.squeeze(-1)

    S_avail = W1_all.shape[0]
    thin = max(1, int(thin))
    # Every `thin`-th draw, at most `draws` of them, never indexing past S_avail
    # (arange(0, draws*thin, thin) can overrun when thin > 1).
    idxs = np.arange(0, S_avail, thin)[:draws]

    p_in = X_train.shape[1]
    n_tr = X_train.shape[0]
    gammas = list(gamma_list)

    # --------- inner helpers ---------
    def _standardize_pair(A_tr, A_va):
        """Scale train/val extras with the *train* column mean/std (if enabled)."""
        if not standardize_extras:
            return A_tr, A_va
        mu = A_tr.mean(axis=0, keepdims=True)
        sd = A_tr.std(axis=0, keepdims=True) + 1e-12
        return (A_tr - mu) / sd, (A_va - mu) / sd

    def _ridgeless_solve(Z, y):
        """Min-norm least squares with a bias column (lstsq is min-norm when underdetermined)."""
        Z1 = np.c_[np.ones((Z.shape[0], 1)), Z]
        theta, *_ = np.linalg.lstsq(Z1, y, rcond=rcond)
        return theta

    def _predict(Z, theta):
        """Predictions [1, Z] @ theta."""
        Z1 = np.c_[np.ones((Z.shape[0], 1)), Z]
        return Z1 @ theta

    def _rks_padding_from_X(H_base, g, Xtr, Xva):
        """Random input features topping the width up to ceil(g * n_tr); returns (tr, va, p_aug)."""
        p_target = int(np.ceil(g * n_tr))
        m_extra = max(0, p_target - H_base)
        if m_extra == 0:
            return np.zeros((Xtr.shape[0], 0)), np.zeros((Xva.shape[0], 0)), H_base
        W = rng.standard_normal((m_extra, p_in)) * (rks_scale / np.sqrt(p_in))
        b = rng.standard_normal(m_extra) * rks_scale
        Z_extra_tr = rks_act(Xtr @ W.T + b[None, :])
        Z_extra_va = rks_act(Xva @ W.T + b[None, :])
        Z_extra_tr, Z_extra_va = _standardize_pair(Z_extra_tr, Z_extra_va)
        Z_extra_tr *= tail_scale
        Z_extra_va *= tail_scale
        return Z_extra_tr, Z_extra_va, H_base + m_extra

    def _orth_padding_from_Z(Ztr_base, Zva_base, g):
        """
        Create extras orthogonal (on train) to span(Ztr_base) in *sample space*,
        then standardize + tail-shrink. Returns (tr, va, p_aug).
        """
        H_base = Ztr_base.shape[1]
        p_target = int(np.ceil(g * n_tr))
        m_extra = max(0, p_target - H_base)
        if m_extra == 0:
            return np.zeros((Ztr_base.shape[0], 0)), np.zeros((Zva_base.shape[0], 0)), H_base

        Z = Ztr_base
        Z_pinv = np.linalg.pinv(Z)        # (H_base, n_tr)
        P = Z @ Z_pinv                    # (n_tr, n_tr)
        I = np.eye(n_tr)

        G_tr = rng.standard_normal((n_tr, m_extra))
        N_tr = (I - P) @ G_tr

        # validation: fresh proposals minus the same train-fitted coefficients
        G_va = rng.standard_normal((Zva_base.shape[0], m_extra))
        N_va = G_va - Zva_base @ (Z_pinv @ G_tr)

        N_tr, N_va = _standardize_pair(N_tr, N_va)
        N_tr *= tail_scale
        N_va *= tail_scale
        return N_tr, N_va, H_base + m_extra
    # ---------------------------------

    # per-gamma storage, per posterior draw
    pred_val = {g: [] for g in gammas}  # list of yhat_va arrays, one per draw
    per_sample = {g: {'rmse_tr': [], 'rmse_val': [], 'mse_val': [], 'p_aug': []} for g in gammas}

    # --------- main loop over the selected posterior draws ---------
    for s in idxs:
        # shapes & base hidden features
        W1_s, b1_s, _ = _std_shapes(W1_all[s], b1_all[s], WL_all[s], p_in)
        Ztr = post_acts(X_train, W1_s, b1_s)   # (n_tr, H_base)
        Zva = post_acts(X_val,   W1_s, b1_s)   # (n_val, H_base)
        H_base = Ztr.shape[1]

        for g in gammas:
            # augment hidden features to reach target gamma
            if padding_mode == 'orth':
                Z_extra_tr, Z_extra_va, p_aug = _orth_padding_from_Z(Ztr, Zva, g)
            elif padding_mode == 'rks':
                Z_extra_tr, Z_extra_va, p_aug = _rks_padding_from_X(H_base, g, X_train, X_val)
            else:
                raise ValueError("padding_mode must be 'orth' or 'rks'")

            Z_aug_tr = np.c_[Ztr, Z_extra_tr]
            Z_aug_va = np.c_[Zva, Z_extra_va]

            # ridgeless min-norm on augmented hidden features
            theta   = _ridgeless_solve(Z_aug_tr, y_train)
            yhat_tr = _predict(Z_aug_tr, theta)
            yhat_va = _predict(Z_aug_va, theta)

            # store per-draw predictions (for ensemble bias/variance computation)
            pred_val[g].append(yhat_va)

            # per-draw metrics
            mse_va  = float(np.mean((y_val - yhat_va)**2))   # per-draw MSE
            rmse_tr = float(np.sqrt(np.mean((y_train - yhat_tr)**2)))
            rmse_va = float(np.sqrt(mse_va))

            per_sample[g]['rmse_tr'].append(rmse_tr)
            per_sample[g]['rmse_val'].append(rmse_va)
            per_sample[g]['mse_val'].append(mse_va)
            per_sample[g]['p_aug'].append(p_aug)

    # --------- ensemble bias/variance/risk curves ----------
    bias_curve = []
    variance_curve = []
    risk_curve = []
    rmse_mean_val = []  # one RMSE value per gamma = mean over draws
    rmse_mean_tr  = []

    for g in gammas:
        vals = per_sample[g]['rmse_val']
        trs  = per_sample[g]['rmse_tr']
        rmse_mean_val.append(float(np.mean(vals)) if len(vals) else float('nan'))
        rmse_mean_tr.append(float(np.mean(trs)) if len(trs) else float('nan'))

        preds = pred_val[g]
        if len(preds) == 0:
            bias_curve.append(float('nan'))
            variance_curve.append(float('nan'))
            risk_curve.append(float('nan'))
            continue

        Y = np.vstack(preds)             # (n_draws, n_val)
        mean_pred = Y.mean(axis=0)       # posterior-mean predictor on val
        bias_g = float(((mean_pred - y_val)**2).mean())
        var_g  = float(((Y - mean_pred)**2).mean())
        bias_curve.append(bias_g)
        variance_curve.append(var_g)
        risk_curve.append(bias_g + var_g)

    summary = {
        'gamma': gammas,
        'used_padding': padding_mode,
        'tail_scale': tail_scale,
        'rmse_mean_val': rmse_mean_val,     # plot this vs gamma (one value per gamma)
        'rmse_mean_tr':  rmse_mean_tr,
        'bias_curve': bias_curve,
        'variance_curve': variance_curve,
        'risk_curve': risk_curve
    }

    results = {
        'pred_val': pred_val,     # per gamma: list of yhat_val arrays, one per draw
        'per_sample': per_sample  # per gamma: dict with per-draw rmse_tr/rmse_val/mse_val/p_aug
    }

    return summary, results


In [8]:
# Posterior objects for the four tanh priors, in the order used by the sweep below.
post_gauss, post_RHS, post_DHS, post_DST = (
    tanh_fit[prior_name]['posterior']
    for prior_name in (
        'Gaussian tanh',
        'Regularized Horseshoe tanh',
        'Dirichlet Horseshoe tanh',
        'Dirichlet Student T tanh',
    )
)

In [24]:
# Ridgeless risk-vs-gamma sweep with RKS padding, identical settings for every prior.
# rng=None means each call seeds its own default_rng(0), so padding draws match across priors.
sweep_opts = dict(
    draws=20, thin=1,
    gamma_list=(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.1, 1.5, 2.0, 3.0, 5.0, 7.0, 9.0, 11, 13, 15, 17, 19, 25, 50, 100),
    padding_mode='rks',
    rng=None,
    rcond=1e-10,            # lstsq cutoff
    standardize_extras=True,
    tail_scale=0.9,         # multiplier on the standardized extra columns
    rks_act=np.tanh,
    rks_scale=0.5,          # RKS weight/bias scale (effective ~ rks_scale/sqrt(p_in))
)

summary_gauss, results_gauss = analyze_intensive_augmented_model_2(
    post_gauss, X_train, y_train, X_test, y_test, **sweep_opts)
summary_RHS, results_RHS = analyze_intensive_augmented_model_2(
    post_RHS, X_train, y_train, X_test, y_test, **sweep_opts)
summary_DHS, results_DHS = analyze_intensive_augmented_model_2(
    post_DHS, X_train, y_train, X_test, y_test, **sweep_opts)
summary_DST, results_DST = analyze_intensive_augmented_model_2(
    post_DST, X_train, y_train, X_test, y_test, **sweep_opts)

In [None]:
# Ensemble bias curve vs. gamma for each prior. 'bias_curve' is
# mean_i (posterior-mean prediction_i - y_i)^2 against the noisy test targets,
# so it also contains the irreducible noise level.
fig, ax = plt.subplots()
for prior_label, prior_summary in (("Gauss", summary_gauss), ("RHS", summary_RHS),
                                   ("DHS", summary_DHS), ("DST", summary_DST)):
    ax.plot(prior_summary['gamma'], prior_summary['bias_curve'], label=prior_label)
ax.set_xscale("log")          # gamma spans 0.1 to 100; log keeps the gamma ~ 1 region readable
ax.set_xlabel(r"$\gamma$")    # raw string: "\g" is an invalid escape in a normal/f-string
ax.set_ylabel("Squared bias (incl. noise)")
ax.set_title(r"Ensemble bias vs. $\gamma$")
ax.legend();

In [None]:
# Spread of per-draw test predictions around the posterior-mean prediction, vs. gamma.
fig, ax = plt.subplots()
for prior_label, prior_summary in (("Gauss", summary_gauss), ("RHS", summary_RHS),
                                   ("DHS", summary_DHS), ("DST", summary_DST)):
    ax.plot(prior_summary['gamma'], prior_summary['variance_curve'], label=prior_label)
ax.set_xscale("log")          # gamma spans 0.1 to 100; log keeps the gamma ~ 1 region readable
ax.set_xlabel(r"$\gamma$")    # raw string: "\g" is an invalid escape in a normal/f-string
ax.set_ylabel("Variance across posterior draws")
ax.set_title(r"Ensemble variance vs. $\gamma$")
ax.legend();

In [None]:
# Risk = bias + variance, i.e. the mean over draws of the per-draw test MSE, vs. gamma.
fig, ax = plt.subplots()
for prior_label, prior_summary in (("Gauss", summary_gauss), ("RHS", summary_RHS),
                                   ("DHS", summary_DHS), ("DST", summary_DST)):
    ax.plot(prior_summary['gamma'], prior_summary['risk_curve'], label=prior_label)
ax.set_xscale("log")          # gamma spans 0.1 to 100; log keeps the gamma ~ 1 region readable
ax.set_xlabel(r"$\gamma$")    # raw string: "\g" is an invalid escape in a normal/f-string
ax.set_ylabel("Risk (mean per-draw MSE)")
ax.set_title(r"Ridgeless risk vs. $\gamma$")
ax.legend();

In [None]:
# Mean over posterior draws of the per-draw test RMSE, vs. gamma.
fig, ax = plt.subplots()
for prior_label, prior_summary in (("Gauss", summary_gauss), ("RHS", summary_RHS),
                                   ("DHS", summary_DHS), ("DST", summary_DST)):
    ax.plot(prior_summary['gamma'], prior_summary['rmse_mean_val'], label=prior_label)
ax.set_xscale("log")          # gamma spans 0.1 to 100; log keeps the gamma ~ 1 region readable
ax.set_xlabel(r"$\gamma$")    # raw string: "\g" is an invalid escape in a normal/f-string
ax.set_ylabel("Test RMSE")
ax.set_title(r"Test RMSE vs. $\gamma$")
ax.legend();

In [None]:
# Raw DST test-RMSE curve: one value per entry of summary_DST['gamma'].
summary_DST['rmse_mean_val']

In [None]:
# Mean over posterior draws of the per-draw train RMSE, vs. gamma
# (expected to drop towards 0 once the augmented model can interpolate).
fig, ax = plt.subplots()
for prior_label, prior_summary in (("Gauss", summary_gauss), ("RHS", summary_RHS),
                                   ("DHS", summary_DHS), ("DST", summary_DST)):
    ax.plot(prior_summary['gamma'], prior_summary['rmse_mean_tr'], label=prior_label)
ax.set_xscale("log")          # gamma spans 0.1 to 100; log keeps the gamma ~ 1 region readable
ax.set_xlabel(r"$\gamma$")    # raw string: "\g" is an invalid escape in a normal/f-string
ax.set_ylabel("Train RMSE")
ax.set_title(r"Train RMSE vs. $\gamma$")
ax.legend();