In [None]:
import sys, os; sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__) if '__file__' in globals() else os.getcwd(), '..')))
from utils.model_loader import get_model_fits
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Locations of the dataset and of the stored single-layer tanh regression results.
data_dir = "datasets/abalone"
results_dir_tanh = "results/regression/single_layer/tanh/abalone"

# Models whose fits are compared throughout this notebook.
model_names_tanh = [
    "Gaussian tanh",
    "Regularized Horseshoe tanh",
    "Dirichlet Horseshoe tanh",
    "Dirichlet Student T tanh",
]

# Identifier handed to the loader as `config`.
full_config_path = "abalone_N3341_p8"

# Load the fits for every listed model; prior fits are not requested.
tanh_fit = get_model_fits(
    models=model_names_tanh,
    results_dir=results_dir_tanh,
    config=full_config_path,
    include_prior=False,
)

In [3]:
from utils.generate_data import load_abalone_regression_data
X, X_test, y, y_test = load_abalone_regression_data(standardized=False, frac=1.0)

# Design matrices: coerce to plain float NumPy arrays.
X, X_test = (np.asarray(features, dtype=float) for features in (X, X_test))

# Targets: coerce to float and flatten to one dimension (they may arrive as (n, 1)).
y, y_test = (np.asarray(target, dtype=float).reshape(-1) for target in (y, y_test))



In [4]:
import numpy as np
from numpy.linalg import cholesky, solve
from utils.kappa_matrix import shrinkage_matrix_stable

def build_operators_from_PS(P, S):
    """
    Build the whitened matrix and the shrinkage operators for every draw.

    Parameters
    ----------
    P, S : arrays of shape (D, d, d)
        One matrix pair per draw. Each P[s] (and I + G[s]) must be symmetric
        positive definite, otherwise the Cholesky factorization raises.

    Returns
    -------
    G : (D, d, d)
        Whitened matrix C^{-1} S C^{-T}, where P = C C^T and C is the
        lower-triangular Cholesky factor returned by NumPy. It is symmetric
        and has the same eigenvalues as P^{-1} S.
    shrink_PS : (D, d, d)
        I - shrinkage_matrix_stable(P, S) for each draw.
        NOTE(review): presumably equal to (P+S)^{-1} S -- confirm against
        utils.kappa_matrix.shrinkage_matrix_stable.
    shrink_G : (D, d, d)
        (I+G)^{-1} G.
    """
    n_draws, d, _ = P.shape
    G         = np.empty_like(P, dtype=np.float64)
    shrink_PS = np.empty_like(P, dtype=np.float64)
    shrink_G  = np.empty_like(P, dtype=np.float64)

    I = np.eye(d)

    for s in range(n_draws):
        Ps = P[s]
        Ss = S[s]

        # --- G = C^{-1} S C^{-T}. np.linalg.cholesky returns the LOWER factor,
        # so P = C C^T and both solves must use C itself (not C^T).
        C = cholesky(Ps)
        left = solve(C, Ss)            # C^{-1} S
        Gs = solve(C, left.T).T        # (C^{-1} S^T C^{-T})^T = C^{-1} S C^{-T}
        # Remove round-off asymmetry: the Cholesky below only reads one triangle.
        Gs = 0.5 * (Gs + Gs.T)
        G[s] = Gs

        # --- shrinkage in the original basis, via the project helper
        Rs = shrinkage_matrix_stable(Ps, Ss)
        shrink_PS[s] = I - Rs

        # --- (I+G)^{-1} G via two solves with the Cholesky factor of I+G
        LB = cholesky(I + Gs)
        YB = solve(LB, Gs)             # LB Y = G
        shrink_G[s] = solve(LB.T, YB)  # LB^T X = Y  ->  X = (I+G)^{-1} G

    return G, shrink_PS, shrink_G


# Reload the stored per-draw P and S stacks for each prior and derive the operators.
def _load_PS(npz_path):
    """Return the "P" and "S" arrays of an NPZ archive as float64 arrays.

    The archive is opened in a context manager so its file handle is closed
    once both arrays have been copied out.
    """
    with np.load(npz_path) as archive:
        return archive["P"].astype(np.float64), archive["S"].astype(np.float64)


P_gauss, S_gauss = _load_PS("Abalone_matrices/Gaussian_PS.npz")
G_gauss, shrink_PS_gauss, shrink_G_gauss = build_operators_from_PS(P_gauss, S_gauss)

P_RHS, S_RHS = _load_PS("Abalone_matrices/Regularized_Horseshoe_PS.npz")
G_RHS, shrink_PS_RHS, shrink_G_RHS = build_operators_from_PS(P_RHS, S_RHS)

P_DHS, S_DHS = _load_PS("Abalone_matrices/Dirichlet_Horseshoe_PS.npz")
G_DHS, shrink_PS_DHS, shrink_G_DHS = build_operators_from_PS(P_DHS, S_DHS)

P_DST, S_DST = _load_PS("Abalone_matrices/Dirichlet_StudentT_PS.npz")
G_DST, shrink_PS_DST, shrink_G_DST = build_operators_from_PS(P_DST, S_DST)

In [5]:
import numpy as np
from numpy.linalg import cholesky, solve, eigvalsh, norm

from utils.kappa_matrix import (
    build_hidden_and_jacobian_W, build_Sigma_y, build_S,
    build_P_from_lambda_tau, shrinkage_matrix_stable, extract_model_draws
)

# =========================
# Small utilities
# =========================

def solve_psd_pinv(S, g, rtol=1e-10):
    """
    Apply the Moore-Penrose pseudo-inverse of a symmetric PSD matrix S to the vector g.

    Eigenvalues not exceeding rtol * max(largest eigenvalue, 1.0) are treated
    as zero; if none survives, a zero array shaped like g is returned.
    """
    eigvals, eigvecs = np.linalg.eigh(S)
    cutoff = rtol * max(eigvals.max(), 1.0)
    significant = eigvals > cutoff
    if not significant.any():
        return np.zeros_like(g)

    # Reciprocal spectrum, with the discarded directions left at zero.
    reciprocal = np.zeros_like(eigvals)
    reciprocal[significant] = 1.0 / eigvals[significant]

    coords = eigvecs.T @ g
    return eigvecs @ (reciprocal * coords)

def vec_w1(W1_draw, H, p, vec_order="unit-major"):
    """
    Flatten one draw of first-layer weights into a vector.

    The input is first reshaped to (H, p). It is then flattened either
    row-major ('unit-major', weights grouped by hidden unit) or column-major
    ('feature-major', weights grouped by feature), so the result can be made
    consistent with the column ordering of the Jacobian J_w.

    Raises ValueError for any other vec_order.
    """
    weights = np.asarray(W1_draw).reshape(H, p)
    if vec_order not in ("unit-major", "feature-major"):
        raise ValueError("vec_order must be 'unit-major' or 'feature-major'")
    memory_order = 'C' if vec_order == "unit-major" else 'F'
    return weights.reshape(-1, order=memory_order)

# ----- Low-rank builder that mirrors build_Sigma_y exactly -----
def build_U(Phi_mat, tau_v, J_b1=None, J_b2=None, include_b1=True, include_b2=True):
    """
    Assemble the low-rank factor U by stacking column blocks side by side.

    The first block is sqrt(tau_v**2) * Phi_mat. J_b1 is appended as-is and
    J_b2 as a single column, each only when it is given and its include flag
    is set. With no optional block, the scaled Phi_mat is returned directly.
    """
    blocks = [np.sqrt(tau_v**2) * Phi_mat]
    if include_b1 and J_b1 is not None:
        blocks.append(J_b1)
    if include_b2 and J_b2 is not None:
        blocks.append(J_b2.reshape(-1, 1))

    if len(blocks) == 1:
        return blocks[0]
    return np.concatenate(blocks, axis=1)

# ----- Woodbury apply: returns Σ_y^{-1} B without forming Σ_y -----
def woodbury_apply(U, sigma2, B):
    """
    Apply (sigma2 * I + U U^T)^{-1} to B through the Woodbury identity.

    Only an (r, r) system is solved, where r is the number of columns of U;
    the (n, n) matrix is never formed. B may be (n,) or (n, k). A result with
    a single column is returned flattened to (n,).
    """
    n_rows = U.shape[0]
    rhs = B.reshape(n_rows, -1)                       # (n, k)
    scale = 1.0 / sigma2

    gram = U.T @ U                                    # (r, r)
    core = np.eye(gram.shape[0]) + scale * gram
    coeffs = np.linalg.solve(core, scale * (U.T @ rhs))   # (r, k)

    result = scale * (rhs - U @ coeffs)               # (n, k)
    if result.shape[1] > 1:
        return result
    return result.ravel()

import numpy as np
from numpy.linalg import LinAlgError, eigvalsh, norm, solve

# ---------- robust PSD→SPD helpers ----------
def _sym(A):
    return 0.5 * (A + A.T)

def _min_eig(A):
    """Smallest eigenvalue of the symmetrized matrix as a float, or NaN if the eigensolver fails."""
    try:
        spectrum = eigvalsh(_sym(A))
        return float(np.min(spectrum))
    except LinAlgError:
        # The eigensolver can fail (e.g. on NaN input); report NaN instead of raising.
        return np.nan

def safe_cholesky(A, name="A", jitter0=None, max_tries=8, verbose=True):
    """
    Cholesky factorization that escalates to diagonal jitter when needed.

    The input is symmetrized first. Attempts, in order:
      1. plain Cholesky                      -> returns (L, 0.0)
      2. A + jitter * I, jitter = jitter0 * 10^k for k = 0..max_tries-1
                                             -> returns (L, jitter)
      3. eigen-decomposition with eigenvalues floored at
         max(1e-15, 1e-12 * largest eigenvalue)
                                             -> returns (L, -floor)
    The default jitter0 is 1e-12 * trace(A)/n, or 1e-12 when the trace is not
    positive. `name` only labels the printed diagnostics and the error message.

    Raises LinAlgError if even the eigen-floored matrix cannot be factorized.
    """
    mat = _sym(np.asarray(A, float))
    dim = mat.shape[0]

    if jitter0 is None:
        trace_val = float(np.trace(mat))
        mean_diag = trace_val / dim if trace_val > 0 else 1.0
        jitter0 = 1e-12 * mean_diag

    # 1) No regularization at all.
    try:
        return np.linalg.cholesky(mat), 0.0
    except LinAlgError:
        pass

    # 2) Geometrically growing diagonal shift.
    identity = np.eye(dim)
    shift = jitter0
    for _ in range(max_tries):
        try:
            factor = np.linalg.cholesky(mat + shift * identity)
        except LinAlgError:
            shift *= 10.0
            continue
        if verbose:
            smallest = _min_eig(mat)
            print(f"[safe_cholesky] {name}: added jitter={shift:.2e} "
                  f"(min eig before={smallest:.2e})")
        return factor, shift

    # 3) Last resort: floor the spectrum and rebuild the matrix.
    eigvals, eigvecs = np.linalg.eigh(mat)
    floor = max(1e-15, 1e-12 * np.max(eigvals))
    projected = (eigvecs * np.clip(eigvals, floor, None)) @ eigvecs.T
    try:
        factor = np.linalg.cholesky(projected)
    except LinAlgError as e:
        raise LinAlgError(f"Cholesky failed for {name} even after jitter & clip.") from e
    if verbose:
        print(f"[safe_cholesky] {name}: eigen-floor to {floor:.2e}")
    return factor, -floor  # a negative second value signals the eigen-floor path

# ---------- drop-in replacements for your helpers ----------
def cholesky_powers(P, name="P", verbose=True):
    """
    Cholesky-based factors of P: returns (P_half, P_mhalf, jitter_used).

    P_half is the lower-triangular L from safe_cholesky (P ≈ L L^T) and
    P_mhalf is L^{-T}, obtained by solving L^T X = I. jitter_used is the
    second value reported by safe_cholesky.
    """
    factor, jitter_used = safe_cholesky(P, name=name, verbose=verbose)
    inv_factor_T = solve(factor.T, np.eye(P.shape[0]))
    return factor, inv_factor_T, jitter_used

def shrink_from_PS(P, S, verbose=True):
    """
    Compute (P+S)^{-1} S through a robust Cholesky factorization of P+S.

    Both P+S and S are symmetrized before use. If safe_cholesky had to
    regularize P+S, the result is the inverse of that regularized matrix
    applied to S.
    """
    total = _sym(P + S)
    chol, _jitter = safe_cholesky(total, name="P+S", verbose=verbose)
    forward = solve(chol, _sym(S))      # chol Y = S
    return solve(chol.T, forward)       # chol^T X = Y

def shrink_from_G(G, verbose=True):
    """
    Compute (I+G)^{-1} G through a robust Cholesky factorization of I+G.

    Both I+G and G are symmetrized before use.
    """
    shifted = _sym(np.eye(G.shape[0]) + G)
    chol, _jitter = safe_cholesky(shifted, name="I+G", verbose=verbose)
    forward = solve(chol, _sym(G))      # chol Y = G
    return solve(chol.T, forward)       # chol^T X = Y

def whiten_G_from_PS(P, S, verbose=True):
    """
    Whitened matrix G = L^{-1} S L^{-T}, where P ≈ L L^T.

    L is the lower-triangular factor returned by cholesky_powers. G is a
    congruence transform of S, so it is symmetric and has the same
    eigenvalues as P^{-1} S. The result is symmetrized to remove round-off.

    Both solves must use L itself: solving with L and then with L^T would
    yield L^{-T} L^{-1} S = P^{-1} S, which is not the whitened matrix.
    """
    L, _P_mhalf, _jit = cholesky_powers(_sym(P), name="P", verbose=verbose)
    left = solve(L, _sym(S))       # L^{-1} S
    G = solve(L, left.T).T         # (L^{-1} S^T L^{-T})^T = L^{-1} S L^{-T}
    return _sym(G)


# =========================
# Stepwise checks (one draw)
# =========================

def check_step_1_rebuild_S(
    X, y,
    W1_d, b1_d, W2_d, b2_d,
    sigma_d, tau_v_d,
    S_stored_d,
    include_b1_in_Sigma=True, include_b2_in_Sigma=True,
    activation="tanh",
    vec_order="unit-major",
    dense_crosscheck=False
):
    """
    Step 1 of the single-draw checks: rebuild S and compare it with the stored S.

    The Jacobian pieces come from build_hidden_and_jacobian_W. The rebuilt
    matrix is S_hat = J_w^T Σ_y^{-1} J_w, with Σ_y^{-1} applied through
    woodbury_apply on the low-rank factor from build_U. With
    dense_crosscheck=True the same quantity is also computed from the dense
    Σ_y returned by build_Sigma_y.

    Prints the relative Frobenius error and the maximum absolute error.
    Returns a dict with Phi, Jw, Jb1, Jb2, U, S_hat, S_d, rel_err_S,
    max_abs_S, H and p. The arguments y, b2_d and vec_order are accepted but
    not used here.
    """
    Phi, Jw, Jb1, Jb2 = build_hidden_and_jacobian_W(
        X, W1_d, b1_d, W2_d, activation=activation
    )
    H, p = W1_d.shape

    U = build_U(
        Phi, tau_v_d,
        J_b1=Jb1, J_b2=Jb2,
        include_b1=include_b1_in_Sigma, include_b2=include_b2_in_Sigma,
    )

    # S_hat = J_w^T (Σ_y^{-1} J_w), the inverse being applied via Woodbury.
    S_hat = Jw.T @ woodbury_apply(U, sigma_d**2, Jw)

    def _relative_fro_error(candidate):
        # Frobenius distance to the stored S, relative to its norm (guarded against 0).
        return norm(candidate - S_stored_d, 'fro') / max(norm(S_stored_d, 'fro'), 1e-16)

    rel_err = _relative_fro_error(S_hat)
    max_abs = np.max(np.abs(S_hat - S_stored_d))

    dense_rel = None
    if dense_crosscheck:
        Sigma_y = build_Sigma_y(
            Phi, tau_v=tau_v_d, noise=sigma_d,
            J_b1=Jb1, J_b2=Jb2,
            include_b1=include_b1_in_Sigma, include_b2=include_b2_in_Sigma
        )
        dense_rel = _relative_fro_error(Jw.T @ np.linalg.solve(Sigma_y, Jw))

    dense_note = "" if dense_rel is None else f"   (dense={dense_rel:.3e})"
    print(f"[Step1]  ‖Ŝ−S‖/‖S‖ = {rel_err:.3e}   |Ŝ−S|_∞ = {max_abs:.3e}" + dense_note)

    return dict(
        Phi=Phi, Jw=Jw, Jb1=Jb1, Jb2=Jb2, U=U,
        S_hat=S_hat, S_d=S_stored_d,
        rel_err_S=rel_err, max_abs_S=max_abs,
        H=H, p=p
    )

def _sym(A):
    return 0.5*(A + A.T)

def spectral_shrink_from_G(G, clip_floor=0.0):
    """
    Compute (I+G)^{-1} G from the eigen-decomposition of the symmetrized G.

    With eigenpairs (w_i, u_i), the result is sum_i w_i/(1+w_i) u_i u_i^T.
    Unless clip_floor is None, eigenvalues below clip_floor are raised to it
    first, which removes tiny negative values caused by round-off.
    """
    eigvals, eigvecs = np.linalg.eigh(_sym(G))
    if clip_floor is not None:
        eigvals = np.clip(eigvals, clip_floor, None)
    ratios = eigvals / (1.0 + eigvals)   # stays well-behaved as an eigenvalue approaches 0
    return (eigvecs * ratios) @ eigvecs.T

def check_step_2_shrinkage_identity(P_d, S_d, tol=1e-8):
    """
    Step 2 of the single-draw checks: shrinkage identity between the two bases.

    With the Cholesky factorization P = L L^T and G = L^{-1} S L^{-T}, verify

        (P+S)^{-1} S  ==  L^{-T} (I+G)^{-1} G L^T

    and the equivalent statement in the whitened basis,

        L^T (P+S)^{-1} S L^{-T}  ==  (I+G)^{-1} G.

    The left-hand side comes from shrink_from_PS and the G branch from
    spectral_shrink_from_G. Because L is triangular rather than symmetric,
    the transposes matter: L and L^T coincide only for diagonal P.

    `tol` is the relative tolerance passed to np.allclose (atol is 1e-10).
    Prints a one-line summary and returns a dict with shrink_PS, shrink_G,
    ok_orig, diff_orig, ok_white and diff_white.
    """
    P_d = _sym(P_d)
    S_d = _sym(S_d)

    # L (lower-triangular) and L^{-T} from the robust Cholesky helper.
    L, L_invT, _ = cholesky_powers(P_d, name="P", verbose=False)

    # G = L^{-1} S L^{-T}
    G = _sym(L_invT.T @ S_d @ L_invT)

    # Original-basis shrinkage via (P+S)^{-1} S.
    shrink_PS = shrink_from_PS(P_d, S_d, verbose=False)

    # Whitened shrinkage via the spectral formula (no solves on I+G).
    shrink_G_spec = spectral_shrink_from_G(G, clip_floor=0.0)

    # Map each result into the other basis.
    shrink_from_G_in_orig = L_invT @ shrink_G_spec @ L.T      # L^{-T} (...) L^T
    shrink_PS_in_white    = L.T    @ shrink_PS     @ L_invT   # L^T   (...) L^{-T}

    ok_orig  = np.allclose(shrink_PS,          shrink_from_G_in_orig, rtol=tol, atol=1e-10)
    ok_white = np.allclose(shrink_PS_in_white, shrink_G_spec,         rtol=tol, atol=1e-10)

    diff_orig  = float(np.max(np.abs(shrink_PS - shrink_from_G_in_orig)))
    diff_white = float(np.max(np.abs(shrink_PS_in_white - shrink_G_spec)))

    print(f"[Step2]  shrink match (orig)={ok_orig}  max|Δ|={diff_orig:.2e}   "
          f"(white)={ok_white}  max|Δ|={diff_white:.2e}")

    return dict(
        shrink_PS=shrink_PS,
        shrink_G=shrink_G_spec,
        ok_orig=bool(ok_orig),  diff_orig=diff_orig,
        ok_white=bool(ok_white), diff_white=diff_white
    )


def check_step_3_barw_two_forms(
    y, Jw, U, sigma_d,
    P_d, S_d, shrink_PS=None,
    W1_d=None, b1_d=None, vec_order="unit-major",
    Jb1=None
):
    """
    Step 3 of the single-draw checks: compare two expressions for bar w.

      A: (P+S)^{-1} g
      B: (I - (P+S)^{-1} S) P^{-1} g

    where g = J^T Σ_y^{-1} y* and y* = y + J w0 (+ J_b1 b1 when Jb1 is given).
    Σ_y^{-1} is applied through woodbury_apply on U with variance sigma_d**2.

    Parameters
    ----------
    shrink_PS : optional precomputed (P+S)^{-1} S; computed with
        shrink_from_PS when omitted.
    W1_d, b1_d : required; ValueError is raised if either is None.
    Jb1 : optional Jacobian with respect to b1. When provided, Jb1 @ b1_d is
        added to y*; when omitted, y* contains no bias shift.

    P^{-1} g is obtained with a linear solve, so form B is also correct when
    P_d is not diagonal.

    Prints the maximum absolute and the relative difference, and returns a
    dict with barw_A, barw_B, max_abs and rel.
    """
    if (W1_d is None) or (b1_d is None):
        raise ValueError("Provide W1_d and b1_d for y*.")
    H, p = W1_d.shape
    w0_vec = vec_w1(W1_d, H, p, vec_order=vec_order)

    # y* = y + J w0, plus the bias shift when its Jacobian is supplied.
    y_star = y + (Jw @ w0_vec)
    if Jb1 is not None:
        y_star = y_star + (Jb1 @ b1_d)

    # g = J^T Σ_y^{-1} y*
    r = woodbury_apply(U, sigma_d**2, y_star)
    g = Jw.T @ r

    if shrink_PS is None:
        shrink_PS = shrink_from_PS(P_d, S_d)
    I = np.eye(P_d.shape[0])

    # A: direct solve with P + S.
    barw_A = solve(P_d + S_d, g)

    # B: shrinkage form. Use the full P, not only its diagonal.
    z = solve(P_d, g)
    barw_B = (I - shrink_PS) @ z

    diff = np.max(np.abs(barw_A - barw_B))
    rel = norm(barw_A - barw_B) / max(norm(barw_A), 1e-16)

    print(f"[Step3]  barw forms:  max|Δ| = {diff:.3e}   rel = {rel:.3e}")
    return dict(barw_A=barw_A, barw_B=barw_B, max_abs=float(diff), rel=float(rel))

# =========================
# Full fast mean (many draws)
# =========================

def compute_linearized_mean_fast_fixed(
    X, y,
    W_1, b_1, W_2, b_2,
    noise_all, tau_v_all,
    P_all=None, S_all=None, shrink_PS_all=None,  # prefer supplying stored arrays
    lambda_all=None, tau_w_all=None,             # only used if P_all is None (fallback)
    activation="tanh",
    include_b1_in_Sigma=True,
    include_b2_in_Sigma=True,
    vec_order="unit-major",
    use_woodbury=True,
    return_mats=False,
    D_lim=None
):
    """
    Linearized mean bar w of the first-layer weights, one row per draw.

    For each draw d:
      y*    = y + J_w w0 + J_b1 b1
      g     = J_w^T Σ_y^{-1} y*      (Woodbury, or a dense solve when
                                      use_woodbury=False)
      z     = g / diag(P_d)
      bar w = z - shrink_PS_d z

    P_d is taken from P_all, or rebuilt with build_P_from_lambda_tau from
    (lambda_all, tau_w_all). S_d is taken from S_all, or rebuilt as
    J_w^T Σ_y^{-1} J_w. shrink_PS_d is taken from shrink_PS_all, or computed
    with shrink_from_PS.

    NOTE(review): z uses only the diagonal of P_d, which equals P_d^{-1} g
    only when P_d is diagonal -- confirm that the stored P stacks are diagonal.

    Parameters
    ----------
    W_1 : array of shape (D, H, p); the remaining per-draw inputs are indexed
        by the same draw index.
    D_lim : optional cap on the number of draws processed. It is clamped to
        the range [0, D], so a value larger than the number of draws no
        longer causes an out-of-range index.
    return_mats : when True, also return I - shrink_PS_d for every draw.
    b_2 : accepted but not used.

    Returns
    -------
    (R_stack, w_bar_stack) : R_stack has shape (D, N, N) or is None;
        w_bar_stack has shape (D, N), with N = H * p.

    Raises
    ------
    ValueError if P_all is None and lambda_all or tau_w_all is missing.
    """
    D, H, p = W_1.shape
    if D_lim is not None:
        # Never iterate past the draws that exist, and never below zero.
        D = max(0, min(int(D_lim), D))
    N = H * p
    y = np.asarray(y, float)
    y = y.reshape(y.shape[0])

    w_bar_stack = np.empty((D, N))
    R_stack = np.empty((D, N, N)) if return_mats else None

    for d in range(D):
        Phi, Jw, Jb1, Jb2 = build_hidden_and_jacobian_W(
            X, W_1[d], b_1[d], W_2[d], activation=activation
        )
        U = build_U(
            Phi, tau_v_all[d],
            J_b1=Jb1, J_b2=Jb2,
            include_b1=include_b1_in_Sigma, include_b2=include_b2_in_Sigma
        )

        w0_vec = vec_w1(W_1[d], H, p, vec_order=vec_order)
        y_star = y + (Jw @ w0_vec) + (Jb1 @ b_1[d])

        if use_woodbury:
            r_vec = woodbury_apply(U, noise_all[d]**2, y_star)
        else:
            Sigma_y = build_Sigma_y(
                Phi, tau_v=tau_v_all[d], noise=noise_all[d],
                J_b1=Jb1, J_b2=Jb2,
                include_b1=include_b1_in_Sigma, include_b2=include_b2_in_Sigma
            )
            r_vec = solve(Sigma_y, y_star)

        g = Jw.T @ r_vec

        # P for this draw: stored, or rebuilt from the prior scales.
        if P_all is not None:
            P_d = P_all[d]
        else:
            # The rebuilt P must match the recipe used for the stored arrays.
            if (lambda_all is None) or (tau_w_all is None):
                raise ValueError("Need P_all or (lambda_all, tau_w_all) to build P.")
            P_d = build_P_from_lambda_tau(lambda_all[d], tau_w=tau_w_all[d])
        P_diag = np.diag(P_d)

        # S for this draw: stored, or S = J^T Σ^{-1} J via Woodbury.
        if S_all is not None:
            S_d = S_all[d]
        else:
            inv_Sigma_J = woodbury_apply(U, noise_all[d]**2, Jw)
            S_d = Jw.T @ inv_Sigma_J

        # Shrinkage operator: stored, or computed from P and S.
        if shrink_PS_all is not None:
            shrink_PS_d = shrink_PS_all[d]
        else:
            shrink_PS_d = shrink_from_PS(P_d, S_d)

        # bar w = (I - shrink_PS) z
        z = g / P_diag
        bar_w = z - (shrink_PS_d @ z)

        if return_mats:
            # R = I - shrink_PS, built without allocating an identity matrix.
            R_tmp = -shrink_PS_d.copy()
            np.fill_diagonal(R_tmp, 1.0 + np.diag(R_tmp))
            R_stack[d] = R_tmp

        w_bar_stack[d] = bar_w

    return (R_stack if return_mats else None), w_bar_stack

# =========================
# One-call driver for a draw
# =========================

def debug_linearization_once(
    X, y,
    W1_d, b1_d, W2_d, b2_d,
    sigma_d, tau_v_d,
    P_d, S_d,
    include_b1_in_Sigma=True, include_b2_in_Sigma=True,
    activation="tanh",
    vec_order="unit-major",
    dense_S_check=False
):
    """
    Run the three single-draw checks in sequence, each printing its metrics.

      1. check_step_1_rebuild_S          -- rebuild S and compare with S_d
      2. check_step_2_shrinkage_identity -- shrinkage identity in both bases
      3. check_step_3_barw_two_forms     -- two expressions for bar w

    Step 3 reuses Jw and U from step 1 and the (P+S)^{-1} S matrix from
    step 2, so that matrix is not factorized a second time.

    Returns dict(step1=..., step2=..., step3=...) holding each step's result.
    """
    # Step 1: S reconstruction
    step1 = check_step_1_rebuild_S(
        X, y,
        W1_d, b1_d, W2_d, b2_d,
        sigma_d, tau_v_d, S_stored_d=S_d,
        include_b1_in_Sigma=include_b1_in_Sigma, include_b2_in_Sigma=include_b2_in_Sigma,
        activation=activation,
        vec_order=vec_order,
        dense_crosscheck=dense_S_check
    )

    # Step 2: shrinkage identity (with conjugation)
    step2 = check_step_2_shrinkage_identity(P_d, S_d, tol=1e-8)

    # Step 3: two forms of bar w, using the pieces computed above.
    out3 = check_step_3_barw_two_forms(
        y, step1["Jw"], step1["U"], sigma_d,
        P_d, S_d, shrink_PS=step2["shrink_PS"],
        W1_d=W1_d, b1_d=b1_d, vec_order=vec_order
    )

    return dict(step1=step1, step2=step2, step3=out3)


In [6]:
# Linearized mean for every prior, using the stored P and S stacks of that prior.
# shrink_PS_all is not supplied, so the shrinkage operator is computed per draw.
_PS_STACKS_BY_MODEL = {
    'Gaussian tanh': (P_gauss, S_gauss),
    'Regularized Horseshoe tanh': (P_RHS, S_RHS),
    'Dirichlet Horseshoe tanh': (P_DHS, S_DHS),
    'Dirichlet Student T tanh': (P_DST, S_DST),
}

_w_bar_by_model = {}
for _model_name, (_P_stack, _S_stack) in _PS_STACKS_BY_MODEL.items():
    W1, b1, W2, b2, sigma, tau_w, tau_v, lambda_tilde = extract_model_draws(
        tanh_fit, model=_model_name
    )
    _, _w_bar_by_model[_model_name] = compute_linearized_mean_fast_fixed(
        X, y,
        W1, b1, W2, b2,
        sigma, tau_v,
        P_all=_P_stack, S_all=_S_stack,
        activation="tanh",
        include_b1_in_Sigma=True,
        include_b2_in_Sigma=True,
        vec_order="unit-major",
        use_woodbury=True,
        D_lim=None,
    )

w_bar_gauss = _w_bar_by_model['Gaussian tanh']
w_bar_RHS = _w_bar_by_model['Regularized Horseshoe tanh']
w_bar_DHS = _w_bar_by_model['Dirichlet Horseshoe tanh']
w_bar_DST = _w_bar_by_model['Dirichlet Student T tanh']



In [7]:
def _layer_weight_draws(model_name):
    """Return the draws of the Stan variables "W_1" and "W_L" for one fitted model."""
    return (
        tanh_fit[model_name]['posterior'].stan_variable("W_1"),
        tanh_fit[model_name]['posterior'].stan_variable("W_L"),
    )


W_all_gauss, v_all_gauss = _layer_weight_draws('Gaussian tanh')
W_all_RHS, v_all_RHS = _layer_weight_draws('Regularized Horseshoe tanh')
W_all_DHS, v_all_DHS = _layer_weight_draws('Dirichlet Horseshoe tanh')
W_all_DST, v_all_DST = _layer_weight_draws('Dirichlet Student T tanh')

In [8]:
def align_and_compare(W_all, v_all, w_bar_stack, sort_key="abs_v"):
    """
    Remove sign and permutation ambiguity of the hidden units in every draw,
    then compare the linearized mean with the posterior mean.

    Per draw, each hidden unit is flipped so that its output weight v is
    non-negative (the same flip is applied to its rows of W and of w_bar),
    and the units are then reordered by a descending key.

    Parameters
    ----------
    W_all       : (D, H, p) or (D, p, H); singleton dimensions are squeezed away.
    v_all       : per-draw vector of length H, in any shape that flattens to H.
    w_bar_stack : per-draw array with H*p elements, e.g. (D, H*p) or (D, H, p).
    sort_key    : "abs_v" (order by |v|) or "abs_v_times_rownorm"
                  (order by |v| times the row norm of W).

    Returns
    -------
    W_fix    : (D, H, p) aligned weights
    v_fix    : (D, H)    aligned, non-negative output weights
    wbar_fix : (D, H, p) aligned linearized weights
    summary  : dict with RMSE, Corr, CosSim, SignAgree (NaN-ignoring mean of
               wbar_fix versus mean of W_fix over draws) plus H, p and N = H*p

    Raises
    ------
    ValueError if the shapes cannot be reconciled or sort_key is unknown.
    """
    import numpy as np

    W_all = np.asarray(W_all)
    v_all = np.asarray(v_all)
    w_bar_stack = np.asarray(w_bar_stack)
    D = W_all.shape[0]

    # --- H is read off the first v draw; p follows from the first w_bar row ---
    H = np.squeeze(v_all[0]).ravel().size
    if H == 0:
        raise ValueError("v_all[0] seems empty; cannot infer H.")

    first_wbar_size = np.squeeze(w_bar_stack[0]).ravel().size
    if first_wbar_size % H == 0:
        p = first_wbar_size // H
    else:
        # Fall back to the squeezed shape of the first W draw.
        W0 = np.squeeze(W_all[0])
        if W0.ndim != 2:
            raise ValueError(f"Cannot infer (H,p). v length={H}, but w_bar_stack[0] has {first_wbar_size} elems "
                             f"and W_all[0] has shape {W0.shape}.")
        n_rows, n_cols = W0.shape
        p = n_rows if (n_rows != H and n_cols == H) else n_cols

    N = H * p

    def _as_unit_by_feature(raw_W):
        """Squeeze one W draw and return it as (H, p), transposing a (p, H) input."""
        mat = np.squeeze(np.asarray(raw_W, dtype=float))
        if mat.ndim != 2:
            raise ValueError(f"Unexpected W ndim={mat.ndim}, shape={mat.shape}")
        if mat.shape == (H, p):
            return mat
        if mat.shape == (p, H):
            return mat.T
        raise ValueError(f"Cannot coerce W of shape {mat.shape} to (H,p)=({H},{p}).")

    def _as_unit_vector(raw_v):
        """Flatten one v draw to (H,)."""
        vec = np.asarray(raw_v, dtype=float).squeeze().ravel()
        if vec.size != H:
            raise ValueError(f"v has size {vec.size}, expected H={H}.")
        return vec

    def _as_wbar_matrix(raw_wbar):
        """Reshape one w_bar row to (H, p)."""
        flat = np.asarray(raw_wbar, dtype=float).squeeze().ravel()
        if flat.size != H * p:
            raise ValueError(f"w_bar row has {flat.size} elems but H*p={H*p} and not (H,p).")
        return flat.reshape(H, p)

    W_fix = np.empty((D, H, p), dtype=float)
    v_fix = np.empty((D, H), dtype=float)
    wbar_fix = np.empty((D, H, p), dtype=float)

    for d in range(D):
        W_d = _as_unit_by_feature(W_all[d])
        v_d = _as_unit_vector(v_all[d])
        wbar_d = _as_wbar_matrix(w_bar_stack[d])

        # 1) flip units so that v >= 0 (a zero v keeps its sign convention +1)
        flips = np.where(v_d == 0.0, 1.0, np.sign(v_d))
        W_d = W_d * flips[:, None]
        wbar_d = wbar_d * flips[:, None]
        v_d = np.abs(v_d)

        # 2) reorder units by the requested descending key
        if sort_key == "abs_v":
            order = np.argsort(-v_d)
        elif sort_key == "abs_v_times_rownorm":
            order = np.argsort(-(v_d * np.linalg.norm(W_d, axis=1)))
        else:
            raise ValueError(f"Unknown sort_key: {sort_key}")

        W_fix[d] = W_d[order]
        wbar_fix[d] = wbar_d[order]
        v_fix[d] = v_d[order]

    # Report non-finite entries before averaging.
    nan_w = np.isnan(wbar_fix).sum()
    inf_w = np.isinf(wbar_fix).sum()
    if nan_w or inf_w:
        print(f"[warn] wbar_fix contains {nan_w} NaNs and {inf_w} Infs; using nanmean.")

    # Means over draws in the aligned basis (NaNs ignored).
    w_lin_mean = np.nanmean(wbar_fix.reshape(D, -1), axis=0)
    w_post_mean = np.nanmean(W_fix.reshape(D, -1), axis=0)

    gap = w_lin_mean - w_post_mean
    norm_product = np.linalg.norm(w_lin_mean) * np.linalg.norm(w_post_mean)

    summary = dict(
        RMSE=float(np.sqrt(np.mean(gap**2))),
        Corr=float(np.corrcoef(w_lin_mean, w_post_mean)[0, 1]),
        CosSim=float(np.dot(w_lin_mean, w_post_mean) / norm_product),
        SignAgree=float(np.mean(np.sign(w_lin_mean) == np.sign(w_post_mean))),
        H=H, p=p, N=N,
    )
    return W_fix, v_fix, wbar_fix, summary


In [None]:
# --- Helper: pick a "MAP-like" representative draw and plot MAP vs. \bar{w} ---
import numpy as np
import matplotlib.pyplot as plt

def select_map_like_index(W_fix: np.ndarray) -> int:
    """
    Index of the draw closest to the mean over draws (a medoid-style MAP proxy).

    W_fix has shape (D, H, p). Each draw is flattened, and the draw with the
    smallest squared Euclidean distance to the mean of all draws is returned.
    """
    flat_draws = W_fix.reshape(W_fix.shape[0], -1)
    centered = flat_draws - flat_draws.mean(axis=0)[None, :]
    sq_dist = np.einsum('di,di->d', centered, centered)
    return int(np.argmin(sq_dist))

def plot_map_vs_barw(W_fix: np.ndarray, wbar_fix: np.ndarray, title: str = "", alpha=0.7, eps=1e-1):
    r"""
    Overlay scatter of a representative draw's weights and its linearized weights.

    The representative ("MAP-like") draw is chosen by select_map_like_index.
    Its weights W are drawn as dots and the same draw's $\bar{w}$ as crosses,
    both flattened row-major from (H, p).

    Parameters
    ----------
    W_fix, wbar_fix : aligned arrays of shape (D, H, p).
    title : figure title; a default is used when empty.
    alpha : marker transparency.
    eps   : horizontal guide lines are drawn at +eps and -eps to help judge
            which weights are effectively non-zero.
    """
    _, H, p = W_fix.shape
    idx = select_map_like_index(W_fix)  # representative draw
    w_map = W_fix[idx].reshape(-1)
    w_bar = wbar_fix[idx].reshape(-1)

    x = np.arange(1, H*p + 1)
    plt.figure(figsize=(10, 3.5), dpi=150)
    plt.scatter(x, w_map, s=12, marker='o', label="MAP-like $w$", alpha=alpha)
    plt.scatter(x, w_bar, s=18, marker='x', label=r"Linearized $\bar{w}$", alpha=alpha)

    # light vertical guides between hidden units
    for h in range(1, H):
        plt.axvline(h*p + 0.5, color='0.85', lw=1, zorder=0)

    # horizontal guides at the +/- eps threshold
    plt.axhline(eps, color='0.85', lw=1, zorder=0)
    plt.axhline(-eps, color='0.85', lw=1, zorder=0)

    plt.xlabel("parameter index (after alignment)")
    plt.ylabel("value")
    # Raw string: "\b" and "\~" must reach mathtext as backslash sequences.
    plt.title(title if title else r"MAP-like $w$ vs linearized $\bar{w}$")
    plt.legend()
    plt.tight_layout()
    plt.show()


In [None]:
# --- Gaussian prior: align the draws, print the agreement metrics, then plot ---
W_fix_g, v_fix_g, wbar_fix_g, summary_g = align_and_compare(
    W_all=W_all_gauss,
    v_all=v_all_gauss,
    w_bar_stack=w_bar_gauss,
    sort_key="abs_v",
)
print("Gaussian summary:", summary_g)
plot_map_vs_barw(
    W_fix=W_fix_g,
    wbar_fix=wbar_fix_g,
    title="Gaussian prior: MAP-like $w$ vs linearized $\\bar{w}$",
)


In [None]:
# --- Regularized Horseshoe prior: align the draws, print the metrics, then plot ---
W_fix_r, v_fix_r, wbar_fix_r, summary_r = align_and_compare(
    W_all=W_all_RHS,
    v_all=v_all_RHS,
    w_bar_stack=w_bar_RHS,
    sort_key="abs_v",
)
print("RHS summary:", summary_r)
plot_map_vs_barw(
    W_fix=W_fix_r,
    wbar_fix=wbar_fix_r,
    title="RHS prior: MAP-like $w$ vs linearized $\\bar{w}$",
)


In [None]:
# --- Dirichlet Horseshoe prior: align the draws, print the metrics, then plot ---
W_fix_dhs, v_fix_dhs, wbar_fix_dhs, summary_dhs = align_and_compare(
    W_all=W_all_DHS,
    v_all=v_all_DHS,
    w_bar_stack=w_bar_DHS,
    sort_key="abs_v",
)
print("DHS summary:", summary_dhs)
plot_map_vs_barw(
    W_fix=W_fix_dhs,
    wbar_fix=wbar_fix_dhs,
    title="DHS prior: MAP-like $w$ vs linearized $\\bar{w}$",
)


In [None]:

# --- Dirichlet Student-t prior: align the draws, print the metrics, then plot ---
W_fix_dst, v_fix_dst, wbar_fix_dst, summary_dst = align_and_compare(
    W_all=W_all_DST,
    v_all=v_all_DST,
    w_bar_stack=w_bar_DST,
    sort_key="abs_v",
)
print("DST summary:", summary_dst)
plot_map_vs_barw(
    W_fix=W_fix_dst,
    wbar_fix=wbar_fix_dst,
    title="DST prior: MAP-like $w$ vs linearized $\\bar{w}$",
)


## DEBUG

In [None]:
# Load the stored P and S stacks for the Dirichlet Horseshoe model.
# The archive is opened in a context manager so its file handle is released.
with np.load("Abalone_matrices/Dirichlet_Horseshoe_PS.npz") as dat:
    P_DHS, S_DHS = dat["P"].astype(np.float64), dat["S"].astype(np.float64)
    # (Optional) shrink_PS_DHS = dat["shrink_PS"].astype(np.float64)

W1, b1, W2, b2, sigma, tau_w, tau_v, lambda_tilde = extract_model_draws(
    tanh_fit, model='Dirichlet Horseshoe tanh'
)

# Only draws present in both the posterior sample and the stored stacks can be checked.
n_draws_available = min(len(W1), len(P_DHS), len(S_DHS))

for d in (0, 17, 123, 777, 1023):
    if d >= n_draws_available:
        print(f"\n=== Draw {d} skipped: only {n_draws_available} draws available ===")
        continue
    print(f"\n=== Draw {d} ===")
    _ = debug_linearization_once(
        X, y,
        W1[d], b1[d], W2[d], b2[d],
        sigma[d], tau_v[d],
        P_DHS[d], S_DHS[d],
        include_b1_in_Sigma=True, include_b2_in_Sigma=True,
        activation="tanh",
        vec_order="unit-major",     # flip to 'feature-major' if Step1 errors are large
        dense_S_check=False
    )
