In [None]:
import sys, os; sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__) if '__file__' in globals() else os.getcwd(), '..')))
from utils.model_loader import get_model_fits
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Locations of the dataset and of the saved fits for each activation function.
data_dir = f"datasets/abalone"
results_dir_relu = "results/regression/single_layer/relu/abalone"
results_dir_tanh = "results/regression/single_layer/tanh/abalone"
# Model names requested from get_model_fits; the tanh fits carry a " tanh" suffix.
#model_names_relu = ["Dirichlet Student T"]
model_names_relu = ["Gaussian", "Regularized Horseshoe", "Dirichlet Horseshoe", "Dirichlet Student T"]
model_names_tanh = ["Gaussian tanh", "Regularized Horseshoe tanh", "Dirichlet Horseshoe tanh", "Dirichlet Student T tanh"]


# Configuration identifier handed to get_model_fits
# (presumably N=3341 training rows and p=8 features -- TODO confirm).
full_config_path = "abalone_N3341_p8"
# The ReLU fits are currently not loaded; un-comment to load them as well.
# relu_fit = get_model_fits(
#     config=full_config_path,
#     results_dir=results_dir_relu,
#     models=model_names_relu,
#     include_prior=False,
# )

# Load the tanh fits only; later cells index tanh_fit[<model name>]['posterior'].
tanh_fit = get_model_fits(
    config=full_config_path,
    results_dir=results_dir_tanh,
    models=model_names_tanh,
    include_prior=False,
)

In [3]:
from utils.generate_data import load_abalone_regression_data
# Load the unstandardised Abalone split with frac=1.0.
X, X_test, y, y_test = load_abalone_regression_data(standardized=False, frac=1.0)

# Design matrices become plain float64 NumPy arrays.
X, X_test = (np.asarray(part, dtype=float) for part in (X, X_test))

# Targets may arrive as an (n,1) DataFrame/array; flatten them to (n,).
y, y_test = (np.asarray(part, dtype=float).reshape(-1) for part in (y, y_test))


In [4]:
import numpy as np
from numpy.linalg import cholesky, solve
from utils.kappa_matrix import shrinkage_matrix_stable

def build_operators_from_PS(P, S):
    """
    Build the whitened matrix and the two shrinkage operators for every draw.

    Parameters
    ----------
    P, S : ndarray, shape (n_draws, d, d)
        One symmetric positive-definite matrix per draw.

    Returns
    -------
    G : ndarray, shape (n_draws, d, d)
        Cholesky-whitened matrix C^{-1} S C^{-T}, where P = C C^T with C
        lower-triangular. It is symmetric and has the same eigenvalues as
        P^{-1/2} S P^{-1/2}; the two are identical when P is diagonal.
    shrink_PS : ndarray, shape (n_draws, d, d)
        I - shrinkage_matrix_stable(P, S). This equals (P+S)^{-1} S provided
        the helper returns R = (P+S)^{-1} P (assumed; the helper lives in
        utils.kappa_matrix).
    shrink_G : ndarray, shape (n_draws, d, d)
        (I+G)^{-1} G.

    Raises
    ------
    numpy.linalg.LinAlgError
        If some P[s] or I + G[s] is not positive definite (Cholesky fails).
    """
    n_draws, d, _ = P.shape
    G         = np.empty_like(P, dtype=np.float64)
    shrink_PS = np.empty_like(P, dtype=np.float64)
    shrink_G  = np.empty_like(P, dtype=np.float64)

    I = np.eye(d)

    for s in range(n_draws):
        Ps = P[s]; Ss = S[s]

        # --- G = C^{-1} S C^{-T}.
        # numpy.linalg.cholesky returns the LOWER factor (Ps = C @ C.T), so both
        # solves must use C itself; mixing C.T and C gives C^{-1} S C^{-1}, which
        # is not symmetric unless P is diagonal.
        C = cholesky(Ps)
        temp = solve(C, Ss)            # C^{-1} S
        Gs   = solve(C, temp.T).T      # (C^{-1} S^T C^{-T})^T = C^{-1} S C^{-T}
        Gs   = 0.5 * (Gs + Gs.T)       # remove round-off asymmetry
        G[s] = Gs

        # --- (P+S)^{-1} S, obtained as I - R from the project helper
        Rs = shrinkage_matrix_stable(Ps, Ss)
        shrink_PS[s] = I - Rs

        # --- (I+G)^{-1} G via the Cholesky factor of B = I + G = LB LB^T
        B = I + Gs
        LB = cholesky(B)               # also acts as a positive-definiteness check
        YB = solve(LB, Gs)             # LB^{-1} G
        XB = solve(LB.T, YB)           # LB^{-T} LB^{-1} G = B^{-1} G
        shrink_G[s] = XB

    return G, shrink_PS, shrink_G


# Example usage after reloading a saved NPZ:
def _load_PS(npz_path):
    """Return the "P" and "S" stacks stored in an NPZ archive as float64 arrays."""
    # np.load on an .npz keeps the archive open until it is closed; the context
    # manager releases the file handle once both arrays have been copied out.
    with np.load(npz_path) as dat:
        return dat["P"].astype(np.float64), dat["S"].astype(np.float64)


P_gauss, S_gauss = _load_PS("Abalone_matrices/Gaussian_PS.npz")
G_gauss, shrink_PS_gauss, shrink_G_gauss = build_operators_from_PS(P_gauss, S_gauss)

P_RHS, S_RHS = _load_PS("Abalone_matrices/Regularized_Horseshoe_PS.npz")
G_RHS, shrink_PS_RHS, shrink_G_RHS = build_operators_from_PS(P_RHS, S_RHS)

P_DHS, S_DHS = _load_PS("Abalone_matrices/Dirichlet_Horseshoe_PS.npz")
G_DHS, shrink_PS_DHS, shrink_G_DHS = build_operators_from_PS(P_DHS, S_DHS)

P_DST, S_DST = _load_PS("Abalone_matrices/Dirichlet_StudentT_PS.npz")
G_DST, shrink_PS_DST, shrink_G_DST = build_operators_from_PS(P_DST, S_DST)


In [5]:
from utils.kappa_matrix import visualize_models

# Posterior means (average over draws, axis 0) of each stack,
# always in the order Gauss, RHS, DHS, DST.
matrices_S = [np.mean(stack, axis=0) for stack in (S_gauss, S_RHS, S_DHS, S_DST)]
names_S = ["S (Gauss)", "S (RHS)", "S (DHS)", "S (DST)"]

matrices_G = [np.mean(stack, axis=0) for stack in (G_gauss, G_RHS, G_DHS, G_DST)]
names_G = ["G (Gauss)", "G (RHS)", "G (DHS)", "G (DST)"]

matrices_shrink = [
    np.mean(stack, axis=0)
    for stack in (shrink_G_gauss, shrink_G_RHS, shrink_G_DHS, shrink_G_DST)
]
names_shrink = ["(I+G)^{-1}G (Gauss)", "(I+G)^{-1}G (RHS)", "(I+G)^{-1}G (DHS)", "(I+G)^{-1}G (DST)"]

matrices_operator = [
    np.mean(stack, axis=0)
    for stack in (shrink_PS_gauss, shrink_PS_RHS, shrink_PS_DHS, shrink_PS_DST)
]
names_operator = ["(P+S)^{-1}S (Gauss)", "(P+S)^{-1}S (RHS)", "(P+S)^{-1}S (DHS)", "(P+S)^{-1}S (DST)"]

In [None]:
# Number of entries in a 128 x 128 matrix (used as the denominator in the next cell).
128*128

In [None]:
# Fraction of entries of the posterior-mean Gaussian S matrix with magnitude above 1e-4.
# Averaging the boolean mask divides by the matrix's actual number of entries
# instead of a hard-coded 128*128.
np.mean(np.abs(matrices_S[0]) > 1e-4)

In [None]:
# Show the posterior-mean S matrices of the four priors (signed values, H=16 hidden units, p=8 inputs).
visualize_models(matrices_S, names_S, H=16, p=8, use_abs=False)

In [None]:
# Show the posterior-mean whitened matrices G of the four priors.
visualize_models(matrices_G, names_G, H=16, p=8, use_abs=False)#, cmap="magma")


In [None]:
# Show the posterior-mean (I+G)^{-1}G operators of the four priors.
visualize_models(matrices_shrink, names_shrink, H=16, p=8, use_abs=False)#, cmap="magma")

In [None]:
# Show the posterior-mean (P+S)^{-1}S operators of the four priors.
visualize_models(matrices_operator, names_operator, H=16, p=8, use_abs=False)#, cmap="magma")

In [10]:
# --- Per-draw traces of the two shrinkage operators, kept as distributions over draws ---
import matplotlib.pyplot as plt

# Trace of (I+G)^{-1}G for every draw (labelled "effective dof" in the histogram cell).
tr_R_gauss = np.trace(shrink_G_gauss, axis1=1, axis2=2)
tr_R_RHS   = np.trace(shrink_G_RHS,   axis1=1, axis2=2)
tr_R_DHS   = np.trace(shrink_G_DHS,   axis1=1, axis2=2)
tr_R_DST   = np.trace(shrink_G_DST,   axis1=1, axis2=2)

# Trace of the shrink_PS_* stacks for every draw; build_operators_from_PS stores
# them as I - R, documented there as (P+S)^{-1}S.
tr_SPinvS_gauss = np.trace(shrink_PS_gauss, axis1=1, axis2=2)
tr_SPinvS_RHS   = np.trace(shrink_PS_RHS,   axis1=1, axis2=2)
tr_SPinvS_DHS   = np.trace(shrink_PS_DHS,   axis1=1, axis2=2)
tr_SPinvS_DST   = np.trace(shrink_PS_DST,   axis1=1, axis2=2)



In [None]:
# Overlaid histograms of the per-draw traces: one figure per operator, one layer per prior.
bins = 40
prior_labels = ("Gauss", "RHS", "DHS", "DST")

# Figure 1: trace((I+G)^{-1}G)
plt.figure(figsize=(8,4), dpi=150)
for draws, prior in zip((tr_R_gauss, tr_R_RHS, tr_R_DHS, tr_R_DST), prior_labels):
    plt.hist(draws, bins=bins, alpha=0.5, label=prior)
plt.xlabel("trace((I+G)^{-1}G)  [effective dof]")
plt.ylabel("count")
plt.legend()
plt.tight_layout()
plt.show()

# Figure 2: trace((P+S)^{-1}S)
plt.figure(figsize=(8,4), dpi=150)
for draws, prior in zip((tr_SPinvS_gauss, tr_SPinvS_RHS, tr_SPinvS_DHS, tr_SPinvS_DST), prior_labels):
    plt.hist(draws, bins=bins, alpha=0.5, label=prior)
plt.xlabel(r"$tr((P+S)^{-1}S)$")
plt.ylabel("Frequency")
plt.legend()
plt.tight_layout()
plt.show()

In [12]:
import numpy as np

def plot_point1_aligned(A, B, nameA="A", nameB="B",
                        H=16, p=8, use_abs=False, q_low=0.05, q_high=0.99):
    """
    Point (1): Best-scale–aligned difference.
      - Panel 1: A
      - Panel 2: c*·B  (c* = <A,B>_F / ||B||_F^2)
      - Panel 3: A - c*·B
      - Panel 4: (blank filler)

    A and B are converted to float arrays. The Frobenius cosine between them is
    reported in the title of panel 3; it is defined as 0 when either matrix is
    all-zero, and c* is 0 when B is all-zero. The four panels are handed to
    visualize_models together with H, p, use_abs, q_low and q_high.
    """
    A = np.asarray(A, float); B = np.asarray(B, float)
    num = np.sum(A * B)
    B_energy = np.sum(B * B)
    den = B_energy if B_energy != 0 else 1.0
    c_star = num / den
    # The epsilon only protects against ||B||_F == 0; an all-zero A would still
    # give 0/0, so that case is handled explicitly.
    norm_prod = np.linalg.norm(A, "fro") * (np.linalg.norm(B, "fro") + 1e-12)
    cosF = num / norm_prod if norm_prod != 0 else 0.0

    mats  = [A, c_star * B, A - c_star * B, np.zeros_like(A)]
    names = [
        f"{nameA}",
        f"{nameB} scaled (c*={c_star:.3g})",
        f"Aligned diff: {nameA} − c*·{nameB}\ncos_F={cosF:.3f}",
        "(unused)"
    ]
    visualize_models(mats, names, H=H, p=p, use_abs=use_abs, q_low=q_low, q_high=q_high)

def plot_point2_unit_energy(A, B, nameA="A", nameB="B",
                            H=16, p=8, use_abs=False, q_low=0.05, q_high=0.99):
    """
    Point (2): compare two matrices after scaling each to unit Frobenius norm.

    Panels passed to visualize_models, in order:
      1. A / ||A||_F
      2. B / ||B||_F
      3. their difference
      4. an all-zero filler panel
    A small constant (1e-12) is added to each norm before dividing.
    """
    A = np.asarray(A, float)
    B = np.asarray(B, float)
    A_unit, B_unit = (M / (np.linalg.norm(M, "fro") + 1e-12) for M in (A, B))

    titles = [
        f"{nameA} / ||{nameA}||_F",
        f"{nameB} / ||{nameB}||_F",
        "Difference (unit-energy)",
        "(unused)"
    ]
    panels = [A_unit, B_unit, A_unit - B_unit, np.zeros_like(A)]
    visualize_models(panels, titles, H=H, p=p, use_abs=use_abs, q_low=q_low, q_high=q_high)


In [None]:
# Best-scale-aligned comparison of the posterior-mean shrink_PS operators: Dirichlet-HS vs RHS.
plot_point1_aligned(np.mean((shrink_PS_DHS), axis=0), np.mean((shrink_PS_RHS), axis=0), "Dirichlet–HS", "RHS")


In [None]:
# Best-scale-aligned comparison of the posterior-mean shrink_PS operators: Dirichlet-ST vs RHS.
plot_point1_aligned(np.mean((shrink_PS_DST), axis=0), np.mean((shrink_PS_RHS), axis=0), "Dirichlet–ST", "RHS")


In [None]:
# Best-scale-aligned comparison of the posterior-mean shrink_PS operators: Dirichlet-HS vs Gaussian.
plot_point1_aligned(np.mean((shrink_PS_DHS), axis=0), np.mean((shrink_PS_gauss), axis=0), "Dirichlet–HS", "Gaussian")


In [None]:
# Best-scale-aligned comparison of the posterior-mean shrink_PS operators: Dirichlet-ST vs Gaussian.
plot_point1_aligned(np.mean((shrink_PS_DST), axis=0), np.mean((shrink_PS_gauss), axis=0), "Dirichlet–ST", "Gaussian")


In [None]:
# Unit-Frobenius-norm comparison of the posterior-mean shrink_PS operators: Dirichlet-HS vs RHS.
plot_point2_unit_energy(np.mean((shrink_PS_DHS), axis=0), np.mean((shrink_PS_RHS), axis=0), "Dirichlet–HS", "RHS")


In [None]:
# Unit-Frobenius-norm comparison of the posterior-mean shrink_PS operators: Dirichlet-ST vs RHS.
plot_point2_unit_energy(np.mean((shrink_PS_DST), axis=0), np.mean((shrink_PS_RHS), axis=0), "Dirichlet–ST", "RHS")

In [None]:
# Unit-Frobenius-norm comparison of the posterior-mean shrink_PS operators: Dirichlet-HS vs Gaussian.
plot_point2_unit_energy(np.mean((shrink_PS_DHS), axis=0), np.mean((shrink_PS_gauss), axis=0), "Dirichlet–HS", "Gauss")

In [None]:
# Unit-Frobenius-norm comparison of the posterior-mean shrink_PS operators: Dirichlet-ST vs Gaussian.
plot_point2_unit_energy(np.mean((shrink_PS_DST), axis=0), np.mean((shrink_PS_gauss), axis=0), "Dirichlet–ST", "Gauss")

In [None]:
# Best-scale-aligned comparison of the posterior-mean shrink_PS operators: Gaussian vs RHS.
plot_point1_aligned(np.mean((shrink_PS_gauss), axis=0), np.mean((shrink_PS_RHS), axis=0), "Gauss", "RHS")

In [None]:
# Unit-Frobenius-norm comparison of the posterior-mean shrink_PS operators: Gaussian vs RHS.
plot_point2_unit_energy(np.mean((shrink_PS_gauss), axis=0), np.mean((shrink_PS_RHS), axis=0), "Gauss", "RHS")

In [None]:
# All four posterior-mean shrink_PS operators side by side, with explicit q_low/q_high passed to the helper.
mats = [np.mean(shrink_PS_gauss, axis=0), np.mean(shrink_PS_RHS, axis=0), np.mean(shrink_PS_DHS, axis=0), np.mean(shrink_PS_DST, axis=0)]
names = ["Gaussian", "RHS", "Dirichlet–HS", "Dirichlet–ST"]
visualize_models(mats, names, H=16, p=8, use_abs=False, q_low=0.05, q_high=0.99)


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.spatial.distance import pdist, squareform

# ------------------------------------------------------
# Assume: shrink_PS_gauss etc. each are (4000, 128, 128)
# ------------------------------------------------------

models = {
    "Gaussian": shrink_PS_gauss,
    "RHS": shrink_PS_RHS,
    "Dirichlet–HS": shrink_PS_DHS,
    "Dirichlet–ST": shrink_PS_DST,
}

# --- 1) Posterior mean of each operator ---
mean_ops = {name: np.mean(arr, axis=0) for name, arr in models.items()}

# --- 2) Pairwise cosine distances between the flattened posterior means ---
K_flat = np.stack([v.ravel() for v in mean_ops.values()])
dist_matrix = squareform(pdist(K_flat, metric="cosine"))

plt.figure(figsize=(6, 5))
sns.heatmap(
    dist_matrix,
    annot=True, fmt=".3f",
    xticklabels=list(mean_ops.keys()),
    yticklabels=list(mean_ops.keys()),
    cmap="mako", square=True, cbar_kws={"label": "Cosine distance"}
)
plt.title("Pairwise cosine distances between mean shrinkage operators")
plt.tight_layout()
plt.show()

# --- 3) Eigenvalue spectra across posterior draws ---
plt.figure(figsize=(7, 5))
for name, arr in models.items():
    # (P+S)^{-1}S is a product of two symmetric matrices and is not symmetric in
    # general, so the symmetric solver eigvalsh (which reads only one triangle)
    # would return the spectrum of a different matrix. Use the general solver.
    # The eigenvalues are real when P and S are symmetric and P+S is positive
    # definite, so only round-off imaginary parts are dropped.
    eigs_all = np.sort(np.linalg.eigvals(arr).real, axis=1)[:, ::-1]   # (draws, N), descending per draw
    eigs_mean = np.mean(eigs_all, axis=0)
    eigs_std  = np.std(eigs_all, axis=0)
    positions = np.arange(len(eigs_mean))
    plt.plot(positions, eigs_mean, label=name)
    plt.fill_between(positions,
                     eigs_mean - eigs_std,
                     eigs_mean + eigs_std,
                     alpha=0.2)
plt.xlabel("Eigenvalue index (sorted)")
plt.ylabel("Eigenvalue magnitude")
plt.legend()
plt.title("Eigenvalue spectra of shrinkage operators")
plt.tight_layout()
plt.show()


In [54]:
import numpy as np
import matplotlib.pyplot as plt

# --- constants and block index slices (adapt if your vec ordering is different)
H, p = 16, 8
BLOCKS = [slice(start, start + p) for start in range(0, H * p, p)]

def block_energy(U, blocks=BLOCKS):
    """
    Share of each column's squared entries that falls into each row block.

    U is indexed as (rows, modes); the result has shape (modes, len(blocks))
    and every row is normalised to sum to one.
    """
    n_modes = U.shape[1]
    energy = np.empty((n_modes, len(blocks)))
    squared = U ** 2
    for col, rows in enumerate(blocks):
        energy[:, col] = squared[rows, :].sum(axis=0)
    energy /= energy.sum(axis=1, keepdims=True)
    return energy

def evd_metrics(G):
    """
    Eigen-decompose G with the symmetric solver and derive per-mode summaries.

    Returns a dict with eigenvalues "w" (descending), eigenvectors "U" (columns,
    same order), "rho" = w / (1 + w), "m_eff" = sum of rho, the inverse
    participation ratio "ipr" and its reciprocal "eff_supp".
    """
    evals, evecs = np.linalg.eigh(G)          # treats G as symmetric
    descending = np.argsort(evals)[::-1]      # largest eigenvalue first
    evals = evals[descending]
    evecs = evecs[:, descending]
    shrink = evals / (1.0 + evals)
    ipr = np.sum(evecs ** 4, axis=0)          # inverse participation ratio per mode
    return {
        "w": evals,
        "U": evecs,
        "rho": shrink,
        "m_eff": shrink.sum(),
        "ipr": ipr,
        "eff_supp": 1.0 / ipr,                # effective support
    }

def m_eff_blocks_from_G(G):
    """Split m_eff of G over the BLOCKS by weighting each mode's rho with its block shares; returns (H,)."""
    metrics = evd_metrics(G)
    shares = block_energy(metrics["U"], BLOCKS)            # (modes, H)
    return (metrics["rho"][:, None] * shares).sum(axis=0)  # (H,)

In [55]:
# ===== Requires tanh_fit, H, m_eff_blocks_from_G and the G_* stacks from earlier cells =====
W2_gauss_samps = tanh_fit['Gaussian tanh']['posterior'].stan_variable("W_L")#[:100]
W2_RHS_samps = tanh_fit['Regularized Horseshoe tanh']['posterior'].stan_variable("W_L")#[:100]
W2_DHS_samps = tanh_fit['Dirichlet Horseshoe tanh']['posterior'].stan_variable("W_L")#[:100]
W2_DST_samps = tanh_fit['Dirichlet Student T tanh']['posterior'].stan_variable("W_L")#[:100]

# Number of draws, taken from the operator stacks instead of the previous
# hard-coded 4000 so the cell also works when fewer draws are available.
S = min(len(G_gauss), len(G_RHS), len(G_DHS), len(G_DST))

# --- compute m_eff per block for ALL draws
m_eff_blocks_GAUSS = np.zeros((S, H))
m_eff_blocks_RHS   = np.zeros((S, H))
m_eff_blocks_DHS   = np.zeros((S, H))
m_eff_blocks_DST   = np.zeros((S, H))
for s in range(S):
    m_eff_blocks_GAUSS[s] = m_eff_blocks_from_G(G_gauss[s])
    m_eff_blocks_RHS[s]   = m_eff_blocks_from_G(G_RHS[s])
    m_eff_blocks_DHS[s]   = m_eff_blocks_from_G(G_DHS[s])
    m_eff_blocks_DST[s]   = m_eff_blocks_from_G(G_DST[s])


# --- bring |W2| into the same (S, H) shape
# Assumes draw s of W_L belongs to operator s (i.e. the stacks were built from
# the leading draws) -- TODO confirm if the stacks were subsampled differently.
W2_GAUSS_flat = np.abs(np.atleast_2d(W2_gauss_samps)[:S].reshape(S, H))
W2_RHS_flat   = np.abs(np.atleast_2d(W2_RHS_samps)[:S].reshape(S, H))
W2_DHS_flat   = np.abs(np.atleast_2d(W2_DHS_samps)[:S].reshape(S, H))
W2_DST_flat   = np.abs(np.atleast_2d(W2_DST_samps)[:S].reshape(S, H))

# --- flatten to 1D for the scatter plot
x_gau = m_eff_blocks_GAUSS.ravel()
y_gau = W2_GAUSS_flat.ravel()
x_rhs = m_eff_blocks_RHS.ravel()
y_rhs = W2_RHS_flat.ravel()
x_dhs = m_eff_blocks_DHS.ravel()
y_dhs = W2_DHS_flat.ravel()
x_dst = m_eff_blocks_DST.ravel()
y_dst = W2_DST_flat.ravel()

In [None]:
# Per-block m_eff against |W2|, pooled over draws and hidden units, one colour per prior.
plt.figure()
for xs, ys, prior in (
    (x_gau, y_gau, "Gaussian"),
    (x_rhs, y_rhs, "RHS"),
    (x_dhs, y_dhs, "DHS"),
    (x_dst, y_dst, "DST"),
):
    plt.scatter(xs, ys, label=prior, s=8, alpha=0.35)
plt.xlabel(r"$m_{\mathrm{eff}}^{(b)}$")
plt.ylabel(r"$|W_2|$")
plt.legend()
plt.show()

In [30]:
from utils.generate_data import load_abalone_regression_data
# Reload the unstandardised Abalone split, this time with frac=0.2; X, X_test, y, y_test are rebound.
X, X_test, y, y_test = load_abalone_regression_data(standardized=False, frac=0.2)

# Feature matrices -> plain float64 NumPy arrays.
X, X_test = (np.asarray(arr, dtype=float) for arr in (X, X_test))

# Targets often come as an (n,1) DataFrame/array -> flatten to (n,).
y, y_test = (np.asarray(arr, dtype=float).reshape(-1) for arr in (y, y_test))

In [None]:
from utils.kappa_matrix import extract_model_draws, compute_shrinkage_for_W_block, shrinkage_eigs_and_df
from utils.sparsity import local_prune_weights

def compute_shrinkage_with_pruning(
    X,
    W_all, b_all, v_all,          # (D,H,p), (D,H), (D,H)
    sigma_all, tau_w_all, tau_v_all,  # (D,), (D,), (D,)
    lambda_all,                   # (D,H,p)
    activation="tanh",
    return_mats=True,             # set False if you only want summaries
    include_b1_in_Sigma: bool = True,
    include_b2_in_Sigma: bool = True,
    sparsity = 0.9
):
    """
    Loop over draws, prune each draw's first-layer weights and recompute the
    shrinkage quantities with the single-draw helper.

    For draw d the weights W_all[d] are multiplied by the first element returned
    by local_prune_weights(W_all[d], sparsity_level=sparsity) and passed to
    compute_shrinkage_for_W_block. Only the diagonal of the returned P is used
    to whiten S, i.e. P is treated as diagonal here.

    Returns (with N = H*p; the five stacks are None when return_mats=False)
      R_stack      : (D, N, N)  R as returned by compute_shrinkage_for_W_block
                                 (documented by the caller as (P+S)^{-1}P)
      S_stack      : (D, N, N)  S from the helper
      P_stack      : (D, N, N)  P from the helper
      G_stack      : (D, N, N)  diag(P)^{-1/2} S diag(P)^{-1/2}
      shrink_stack : (D, N, N)  (I+G)^{-1} G
      r_eigs       : (D, N)     first output of shrinkage_eigs_and_df(P, S), sorted ascending
      df_eff       : (D,)       second output of shrinkage_eigs_and_df(P, S)
    """
    # Named n_inputs (not p) so it cannot be confused with the diagonal of P below.
    D, H, n_inputs = W_all.shape
    N = H * n_inputs

    R_stack = np.empty((D, N, N)) if return_mats else None
    S_stack = np.empty((D, N, N)) if return_mats else None
    P_stack = np.empty((D, N, N)) if return_mats else None
    G_stack = np.empty((D, N, N)) if return_mats else None
    shrink_stack= np.empty((D, N, N)) if return_mats else None
    r_eigs  = np.empty((D, N))
    df_eff  = np.empty(D)

    eye_N = np.identity(N)   # loop-invariant

    for d in range(D):
        mask = local_prune_weights(W_all[d], sparsity_level=sparsity)
        W_pruned = mask[0]*W_all[d]
        R, P, S, Sigma_y, _, _ = compute_shrinkage_for_W_block(
            X=X,
            W0=W_pruned,
            b0=b_all[d],
            v0=v_all[d],
            noise=float(sigma_all[d]),
            tau_w=float(tau_w_all[d]),
            tau_v=float(tau_v_all[d]),
            lambda_tilde=lambda_all[d],
            activation=activation,
            include_b1_in_Sigma=include_b1_in_Sigma,
            include_b2_in_Sigma=include_b2_in_Sigma,
        )
        # Whitening by a diagonal matrix is a row/column rescaling; no dense
        # diagonal matrices or matrix products are needed.
        inv_sqrt = 1.0 / np.sqrt(np.diag(P))
        G = inv_sqrt[:, None] * S * inv_sqrt[None, :]
        # Solve (I+G) X = G instead of forming the explicit inverse.
        shrink_mat = np.linalg.solve(eye_N + G, G)

        if return_mats:
            R_stack[d] = R
            S_stack[d] = S
            P_stack[d] = P
            G_stack[d] = G
            shrink_stack[d] = shrink_mat

        r, df = shrinkage_eigs_and_df(P, S)
        r_eigs[d] = np.sort(r)
        df_eff[d] = df

    return R_stack, S_stack, P_stack, G_stack, shrink_stack, r_eigs, df_eff


# Recompute the operators with pruned first-layer weights, one prior at a time.
# NOTE(review): S_gauss, P_gauss, G_gauss (and the RHS/DHS/DST counterparts) assigned
# below overwrite the stacks of the same names that an earlier cell loaded from the
# NPZ files, so cells that use those names depend on execution order.
W, b1, v, b2, noise, tau_w, tau_v, lambda_eff = extract_model_draws(
    tanh_fit, model='Gaussian tanh'
)
# NOTE(review): the last output is named df_gauss here but df_eff_<prior> for the other priors.
R_gauss, S_gauss, P_gauss, G_gauss, shrink_gauss, eigs_gauss, df_gauss = compute_shrinkage_with_pruning(
    X, W, b1, v, noise, tau_w, tau_v, lambda_eff,
    activation="tanh",
    include_b1_in_Sigma=True,
    include_b2_in_Sigma=True,
)
print("done with Gauss")

# Regularized Horseshoe
W, b1, v, b2, noise, tau_w, tau_v, lambda_eff = extract_model_draws(
    tanh_fit, model='Regularized Horseshoe tanh'
)

R_RHS, S_RHS, P_RHS, G_RHS, shrink_RHS, eigs_RHS, df_eff_RHS = compute_shrinkage_with_pruning(
    X, W, b1, v, noise, tau_w, tau_v, lambda_eff,
    activation="tanh",
    include_b1_in_Sigma=True,
    include_b2_in_Sigma=True,
)
print("done with RHS")

# Dirichlet Horseshoe
W, b1, v, b2, noise, tau_w, tau_v, lambda_eff = extract_model_draws(
    tanh_fit, model='Dirichlet Horseshoe tanh'
)
R_DHS, S_DHS, P_DHS, G_DHS, shrink_DHS, eigs_DHS, df_eff_DHS = compute_shrinkage_with_pruning(
    X, W, b1, v, noise, tau_w, tau_v, lambda_eff,
    activation="tanh",
    include_b1_in_Sigma=True,
    include_b2_in_Sigma=True,
)
print("done with DHS")

# Dirichlet Student T
W, b1, v, b2, noise, tau_w, tau_v, lambda_eff = extract_model_draws(
    tanh_fit, model='Dirichlet Student T tanh'
)
R_DST, S_DST, P_DST, G_DST, shrink_DST, eigs_DST, df_eff_DST = compute_shrinkage_with_pruning(
    X, W, b1, v, noise, tau_w, tau_v, lambda_eff,
    activation="tanh",
    include_b1_in_Sigma=True,
    include_b2_in_Sigma=True,
)
print("done with DST")

In [34]:
# --- Per-draw traces for the operators computed from pruned weights ---
import matplotlib.pyplot as plt

# Traces of the shrink_G_* stacks built from the NPZ matrices are left disabled here:
# tr_R_gauss = np.trace(shrink_G_gauss, axis1=1, axis2=2)
# tr_R_RHS   = np.trace(shrink_G_RHS,   axis1=1, axis2=2)
# tr_R_DHS   = np.trace(shrink_G_DHS,   axis1=1, axis2=2)
# tr_R_DST   = np.trace(shrink_G_DST,   axis1=1, axis2=2)

# NOTE: despite the tr_SPinvS_* names, these are traces of the shrink_* stacks returned by
# compute_shrinkage_with_pruning, i.e. of (I+G)^{-1}G. The assignments rebind the
# tr_SPinvS_* variables computed earlier from the shrink_PS_* stacks.
tr_SPinvS_gauss = np.trace(shrink_gauss, axis1=1, axis2=2)
tr_SPinvS_RHS   = np.trace(shrink_RHS,   axis1=1, axis2=2)
tr_SPinvS_DHS   = np.trace(shrink_DHS,   axis1=1, axis2=2)
tr_SPinvS_DST   = np.trace(shrink_DST,   axis1=1, axis2=2)



In [None]:
# Overlaid histograms of the per-draw traces, one layer per prior.
plt.figure(figsize=(8,4), dpi=150)
bins = 40
for draws, prior in (
    (tr_SPinvS_gauss, "Gauss"),
    (tr_SPinvS_RHS, "RHS"),
    (tr_SPinvS_DHS, "DHS"),
    (tr_SPinvS_DST, "DST"),
):
    plt.hist(draws, bins=bins, alpha=0.5, label=prior)
plt.xlabel(r"$tr((P+S)^{-1}S)$")
plt.ylabel("Frequency")
plt.legend()
plt.tight_layout()
plt.show()

## The cells below do not seem to work yet!

In [49]:
from utils.generate_data import load_abalone_regression_data
# Reload the unstandardised Abalone split with frac=1.0; X, X_test, y, y_test are rebound.
X, X_test, y, y_test = load_abalone_regression_data(standardized=False, frac=1.0)

# Inputs as plain float64 NumPy arrays.
X, X_test = (np.asarray(block, dtype=float) for block in (X, X_test))

# Targets may be (n,1) DataFrames/arrays; flatten them to (n,).
y, y_test = (np.asarray(block, dtype=float).reshape(-1) for block in (y, y_test))


In [None]:
from utils.kappa_matrix import build_hidden_and_jacobian_W, build_Sigma_y, build_S, build_P_from_lambda_tau, shrinkage_matrix_stable, extract_model_draws

def solve_psd_pinv(S, g, rtol=1e-10):
    """
    Apply the pseudo-inverse of a symmetric PSD matrix: returns S^+ g.

    Parameters
    ----------
    S : (N, N) symmetric matrix (decomposed with numpy.linalg.eigh).
    g : (N,) vector or (N, k) matrix of right-hand sides.
    rtol : eigenvalues not exceeding rtol * max(largest eigenvalue, 1.0) are
        treated as zero.

    Returns
    -------
    Array shaped like g. All zeros if no eigenvalue exceeds the threshold.
    """
    # symmetric eigendecomposition
    evals, Q = np.linalg.eigh(S)
    # threshold small eigenvalues
    tol = rtol * max(evals.max(), 1.0)
    keep = evals > tol
    if not np.any(keep):
        return np.zeros_like(g)
    inv_eigs = np.zeros_like(evals)
    inv_eigs[keep] = 1.0 / evals[keep]
    # S^+ g = Q diag(inv_eigs) Q^T g
    coeffs = Q.T @ g
    # For a matrix right-hand side the inverse eigenvalues must scale rows, not
    # columns, so they are broadcast along the first axis.
    scale = inv_eigs if coeffs.ndim == 1 else inv_eigs[:, None]
    return Q @ (scale * coeffs)

# ----- Low-rank builder that includes biases exactly like build_Sigma_y -----
def build_U(Phi_mat, tau_v, J_b1=None, J_b2=None, include_b1=True, include_b2=True):
    """
    Column-stack the scaled hidden features with the optional bias Jacobians.

    The first block is sqrt(tau_v**2) * Phi_mat. J_b1 is appended as given and
    J_b2 as a single column, each only if its include flag is set and the array
    is not None. With no extra block the scaled Phi_mat is returned directly.
    """
    blocks = [np.sqrt(tau_v**2) * Phi_mat]          # (n, H)
    use_b1 = include_b1 and (J_b1 is not None)
    use_b2 = include_b2 and (J_b2 is not None)
    if use_b1:
        blocks.append(J_b1)                          # (n, H)
    if use_b2:
        blocks.append(J_b2.reshape(-1, 1))           # (n, 1)
    if len(blocks) == 1:
        return blocks[0]
    return np.concatenate(blocks, axis=1)            # (n, r)

# ----- Woodbury apply: returns Σ_y^{-1} B without forming Σ_y -----
def woodbury_apply(U, sigma2, B):
    """
    Apply (sigma2 * I + U U^T)^{-1} to B through the Woodbury identity.

    U is (n, r); B is (n,) or (n, k). Only an r x r system is solved. The result
    is (n, k), flattened to (n,) when there is a single column.
    """
    n_rows = U.shape[0]
    B = B.reshape(n_rows, -1)                       # (n, k)
    inv_sigma2 = 1.0 / sigma2
    gram = U.T @ U                                  # (r, r)
    small_system = np.eye(gram.shape[0]) + inv_sigma2 * gram
    projected = inv_sigma2 * (U.T @ B)              # (r, k)
    correction = np.linalg.solve(small_system, projected)
    out = inv_sigma2 * (B - U @ correction)         # (n, k)
    if out.shape[1] > 1:
        return out
    return out.ravel()

# ----- Fast mean with biases via Woodbury + your precomputed (R, P) -----
def compute_linearized_mean_fast(
    X, y,
    W_1, b_1, W_2, b_2,
    noise_all, tau_w_all, tau_v_all,
    lambda_all,
    R_all=None,                 # optional: (D, N, N) with R=(P+S)^{-1}P
    shrink_PS_all=None,         # optional: (D, N, N) with (P+S)^{-1}S
    P_all=None,                 # optional: (D, N) diagonals of P, or (D, N, N) full matrices
    activation="tanh",
    include_b1_in_Sigma=True,
    include_b2_in_Sigma=True,
    return_mats=False,
    D_lim = None
):
    """
    Per-draw linearized mean  w_bar = R P^{-1} J_W^T Sigma_y^{-1} y_star.

    Parameters
    ----------
    X, y : inputs and targets; y is flattened to length n.
    W_1 : (D, H, p) first-layer weights; b_1, W_2 : per-draw arrays passed to
        build_hidden_and_jacobian_W. b_2 is accepted but not used.
    noise_all, tau_w_all, tau_v_all, lambda_all : per-draw hyper-parameters.
    R_all : optional precomputed R per draw; takes precedence over shrink_PS_all.
    shrink_PS_all : optional precomputed I - R per draw.
    P_all : optional prior matrices per draw, either as diagonals (D, N) or as
        full matrices (D, N, N) of which only the diagonal is used. When None,
        P is built with build_P_from_lambda_tau.
    return_mats : also return the R matrices that were applied.
    D_lim : optional cap on the number of draws processed (leading draws).

    Returns
    -------
    (R_stack, w_bar_stack) : R_stack is (D, N, N) or None; w_bar_stack is (D, N),
    with D the number of draws actually processed and N = H*p.
    """
    D, H, p = W_1.shape
    if D_lim is not None:
        # Never process more draws than are available.
        D = min(int(D_lim), D)
    N = H * p
    n = y.shape[0]
    y = np.asarray(y, float).reshape(n)

    w_bar_stack = np.empty((D, N))
    R_stack = np.empty((D, N, N)) if return_mats else None

    for d in range(D):
        Phi_mat, JW, Jb1, Jb2 = build_hidden_and_jacobian_W(
            X, W_1[d], b_1[d], W_2[d], activation=activation
        )

        U = build_U(
            Phi_mat, tau_v_all[d],
            J_b1=Jb1, J_b2=Jb2,
            include_b1=include_b1_in_Sigma,
            include_b2=include_b2_in_Sigma,
        )

        w0_vec = W_1[d].reshape(-1)
        # NOTE(review): the working response adds J_W w0 and J_b1 b1 to y but does
        # not subtract the network output at the expansion point -- confirm that
        # this is the intended linearisation target.
        y_star = y + (JW @ w0_vec) + (Jb1 @ b_1[d])

        Sigma_y = build_Sigma_y(Phi_mat,
            tau_v=tau_v_all[d],
            noise=noise_all[d],
            J_b1=Jb1,
            J_b2=Jb2,
            include_b1=include_b1_in_Sigma,
            include_b2=include_b2_in_Sigma
        )
        r_vec = np.linalg.solve(Sigma_y, y_star)
        #r_vec = woodbury_apply(U, noise_all[d]**2, y_star)
        g = JW.T @ r_vec

        # z = P^{-1} g, with P diagonal. The diagonal of draw d is taken as a
        # whole vector; it must not be indexed by d again.
        if P_all is not None:
            P_d = np.asarray(P_all[d])
            p_diag = np.diag(P_d) if P_d.ndim == 2 else P_d
        else:
            P = build_P_from_lambda_tau(lambda_all[d], tau_w=tau_w_all[d])
            p_diag = np.diag(P)
        z = g / p_diag

        if R_all is not None:
            bar_w = R_all[d] @ z
            if return_mats:
                R_stack[d] = R_all[d]
        elif shrink_PS_all is not None:
            # R z = z - shrink_PS z  (no need to build R)
            bar_w = z - (shrink_PS_all[d] @ z)
            if return_mats:
                # R = I - shrink_PS, built by adding 1 to the diagonal of -shrink_PS
                R_tmp = -shrink_PS_all[d].copy()
                np.fill_diagonal(R_tmp, 1.0 + np.diag(R_tmp))
                R_stack[d] = R_tmp
        else:
            # Fallback (slow): compute R from scratch
            P = build_P_from_lambda_tau(lambda_all[d], tau_w=tau_w_all[d])
            S = build_S(JW, U @ U.T + (noise_all[d]**2) * np.eye(n))
            # NOTE(review): relies on shrinkage_matrix_stable accepting want_R -- confirm.
            R = shrinkage_matrix_stable(P, S, want_R=True)
            bar_w = R @ z
            if return_mats:
                R_stack[d] = R

        w_bar_stack[d] = bar_w

    return (R_stack if return_mats else None), w_bar_stack



# Linearized means for the first 100 draws (D_lim=100) of each prior, reusing the
# shrink_PS_* stacks so that R does not have to be rebuilt; P is rebuilt from
# lambda_tilde and tau_w because P_all=None.
W1, b1, W2, b2, sigma, tau_w, tau_v, lambda_tilde = extract_model_draws(tanh_fit, model='Gaussian tanh')
# Gaussian
_, w_bar_gauss = compute_linearized_mean_fast(
    X, y, W1, b1, W2, b2, sigma, tau_w, tau_v, lambda_tilde,
    shrink_PS_all=shrink_PS_gauss, P_all=None, activation="tanh", D_lim=100
)

print("Done Gauss")

W1, b1, W2, b2, sigma, tau_w, tau_v, lambda_tilde = extract_model_draws(tanh_fit, model='Regularized Horseshoe tanh')
# RHS
_, w_bar_RHS = compute_linearized_mean_fast(
    X, y, W1, b1, W2, b2, sigma, tau_w, tau_v, lambda_tilde,
    shrink_PS_all=shrink_PS_RHS, P_all=None, activation="tanh", D_lim=100
)


W1, b1, W2, b2, sigma, tau_w, tau_v, lambda_tilde = extract_model_draws(tanh_fit, model='Dirichlet Horseshoe tanh')
# DHS
_, w_bar_DHS = compute_linearized_mean_fast(
    X, y, W1, b1, W2, b2, sigma, tau_w, tau_v, lambda_tilde,
    shrink_PS_all=shrink_PS_DHS, P_all=None, activation="tanh", D_lim=100
)


W1, b1, W2, b2, sigma, tau_w, tau_v, lambda_tilde = extract_model_draws(tanh_fit, model='Dirichlet Student T tanh')

# DST
_, w_bar_DST = compute_linearized_mean_fast(
    X, y, W1, b1, W2, b2, sigma, tau_w, tau_v, lambda_tilde,
    shrink_PS_all=shrink_PS_DST, P_all=None, activation="tanh", D_lim=100
)



In [46]:
def align_and_compare(W_all, v_all, w_bar_stack, sort_key="abs_v"):
    """
    Align signs & permutations across draws before comparing linearized mean with posterior mean.

    For every draw, each hidden unit is sign-flipped so that its output weight v
    is non-negative (the same flip is applied to that unit's rows of W and
    w_bar), and the units are then reordered by `sort_key`. The summary compares
    the across-draw means of the aligned W and the aligned w_bar.

    Inputs
    ------
    W_all        : array-like, shape (D, H, p) or (D, p, H) or with stray singleton dims.
    v_all        : array-like, shape (D, H) or (D, H, 1) or similar (length H per draw).
    w_bar_stack  : array-like, shape (D, H*p) OR (D, H, p) OR (D, 1, H*p), etc.
    sort_key     : "abs_v" (descending |v|) or "abs_v_times_rownorm"
                   (descending |v| * row norm of W).

    H is taken from the length of v_all[0]; p from the length of w_bar_stack[0]
    (falling back to the shape of W_all[0] if that length is not a multiple of H).
    Networks with H == 1 or p == 1 are supported.

    Returns
    -------
    W_fix        : (D, H, p)   sign/permutation aligned
    v_fix        : (D, H)
    wbar_fix     : (D, H, p)
    summary      : dict with RMSE, Corr, CosSim, SignAgree (means vs means in aligned basis)
                   plus H, p and N = H*p

    Raises
    ------
    ValueError if the shapes cannot be reconciled or sort_key is unknown.
    """
    import numpy as np

    W_all = np.asarray(W_all)
    v_all = np.asarray(v_all)
    w_bar_stack = np.asarray(w_bar_stack)

    D = W_all.shape[0]

    # --- infer H from v (source of truth) ---
    v0 = np.squeeze(v_all[0]).ravel()
    H = v0.size
    if H == 0:
        raise ValueError("v_all[0] seems empty; cannot infer H.")

    # --- infer p from the length of one w_bar row, else from W_all[0] ---
    wb0 = np.squeeze(w_bar_stack[0]).ravel()
    if wb0.size % H == 0:
        p = wb0.size // H
    else:
        W0 = np.squeeze(W_all[0])
        if W0.ndim != 2:
            raise ValueError(f"Cannot infer (H,p). v length={H}, but w_bar_stack[0] has {wb0.size} elems "
                             f"and W_all[0] has shape {np.squeeze(W_all[0]).shape}.")
        h, p_candidate = W0.shape
        # (p, H) layout: the axis that is not H gives p
        p = h if (h != H and p_candidate == H) else p_candidate

    N = H * p

    # alloc outputs
    W_fix = np.empty((D, H, p), dtype=float)
    v_fix = np.empty((D, H), dtype=float)
    wbar_fix = np.empty((D, H, p), dtype=float)

    def coerce_W(Wd, H, p):
        """Return Wd as (H,p). Accepts (H,p), (p,H), or with singleton dims."""
        A = np.squeeze(np.asarray(Wd, dtype=float))
        if A.ndim == 2:
            if A.shape == (H, p):
                return A
            if A.shape == (p, H):
                return A.T
            raise ValueError(f"Cannot coerce W of shape {A.shape} to (H,p)=({H},{p}).")
        # squeeze() also removes a genuine length-1 H or p axis; a flat vector of
        # the right length is unambiguous in that case, so restore the 2-D shape.
        if A.ndim <= 1 and A.size == H * p and (H == 1 or p == 1):
            return A.reshape(H, p)
        raise ValueError(f"Unexpected W ndim={A.ndim}, shape={A.shape}")

    def coerce_v(vd, H):
        """Return vd as (H,)"""
        v = np.asarray(vd, dtype=float).squeeze().ravel()
        if v.size != H:
            raise ValueError(f"v has size {v.size}, expected H={H}.")
        return v

    def coerce_wbar_row(wbd, H, p):
        """Return a w_bar row with H*p elements as (H,p), in row-major order."""
        w = np.asarray(wbd, dtype=float).squeeze().ravel()
        if w.size == H * p:
            return w.reshape(H, p)
        raise ValueError(f"w_bar row has {w.size} elems but H*p={H*p} and not (H,p).")

    for d in range(D):
        # coerce shapes
        Wd = coerce_W(W_all[d], H, p)          # (H,p)
        vd = coerce_v(v_all[d], H)             # (H,)
        wbd = coerce_wbar_row(w_bar_stack[d], H, p)

        # 1) sign fix so v >= 0 (a zero v keeps its sign)
        s = np.sign(vd)
        s[s == 0.0] = 1.0
        Wd = Wd * s[:, None]
        wbd = wbd * s[:, None]
        vd = np.abs(vd)

        # 2) permute units by the chosen key; a stable sort keeps tied units in
        #    their original order
        if sort_key == "abs_v":
            idx = np.argsort(-vd, kind="stable")  # descending |v|
        elif sort_key == "abs_v_times_rownorm":
            idx = np.argsort(-(vd * np.linalg.norm(Wd, axis=1)), kind="stable")
        else:
            raise ValueError(f"Unknown sort_key: {sort_key}")

        W_fix[d] = Wd[idx]
        wbar_fix[d] = wbd[idx]
        v_fix[d] = vd[idx]

    # Compare means in aligned basis
    w_post_mean = W_fix.reshape(D, -1).mean(axis=0)   # (N,)
    w_lin_mean  = wbar_fix.reshape(D, -1).mean(axis=0)

    rmse = float(np.sqrt(np.mean((w_lin_mean - w_post_mean)**2)))
    corr = float(np.corrcoef(w_lin_mean, w_post_mean)[0, 1])
    cos  = float(np.dot(w_lin_mean, w_post_mean) /
                 (np.linalg.norm(w_lin_mean) * np.linalg.norm(w_post_mean)))
    sign_agree = float(np.mean(np.sign(w_lin_mean) == np.sign(w_post_mean)))

    summary = dict(RMSE=rmse, Corr=corr, CosSim=cos, SignAgree=sign_agree,
                   H=H, p=p, N=N)
    return W_fix, v_fix, wbar_fix, summary


In [None]:
# --- Helper: pick a "MAP-like" representative draw and plot MAP vs. \bar{w} ---
import numpy as np
import matplotlib.pyplot as plt

def select_map_like_index(W_fix: np.ndarray) -> int:
    """
    Pick the medoid-like draw: the one whose flattened aligned weights have the
    smallest squared Euclidean (Frobenius) distance to the across-draw mean.

    W_fix: (D, H, p) aligned weights (output of align_and_compare).
    Returns the draw index as a Python int.
    """
    flat = W_fix.reshape(W_fix.shape[0], -1)
    centred = flat - flat.mean(axis=0)[None, :]          # deviation from the aligned mean
    sq_dist = np.einsum('di,di->d', centred, centred)    # squared distance per draw
    return int(np.argmin(sq_dist))

def plot_map_vs_barw(W_fix: np.ndarray, wbar_fix: np.ndarray, title: str = "", alpha=0.7):
    r"""
    Overlay scatter: MAP-like draw's W (dots) vs the same draw's \bar{w} (crosses).

    Both arrays must be aligned and shaped (D, H, p). The representative draw is
    chosen with select_map_like_index(W_fix). Light vertical lines separate the
    hidden units and two horizontal lines mark +/- 0.1. If `title` is empty a
    default title is used.
    """
    _, H, p = W_fix.shape
    idx = select_map_like_index(W_fix)  # representative draw
    w_map = W_fix[idx].reshape(-1)
    w_bar = wbar_fix[idx].reshape(-1)

    eps = 1e-1                          # Small threshold to see non-zero weights

    x = np.arange(1, H*p + 1)
    plt.figure(figsize=(10, 3.5), dpi=150)
    plt.scatter(x, w_map, s=12, marker='o', label="MAP-like $w$", alpha=alpha)
    plt.scatter(x, w_bar, s=18, marker='x', label=r"Linearized $\bar{w}$", alpha=alpha)

    # light vertical guides between hidden units
    for h in range(1, H):
        plt.axvline(h*p + 0.5, color='0.85', lw=1, zorder=0)

    plt.axhline(eps, color='0.85', lw=1, zorder=0)
    plt.axhline(-eps, color='0.85', lw=1, zorder=0)

    plt.xlabel("parameter index (after alignment)")
    plt.ylabel("value")
    # Raw string: "\~" is not a valid Python escape sequence, and the default
    # title now uses the same \bar{w} notation as the legend.
    plt.title(title if title else r"MAP-like $w$ vs linearized $\bar{w}$")
    plt.legend()
    plt.tight_layout()
    plt.show()


In [49]:
# First 100 posterior draws of W_1 and W_L for each prior; the count matches the
# D_lim=100 used when the w_bar_* stacks were computed.
W_all_gauss = tanh_fit['Gaussian tanh']['posterior'].stan_variable("W_1")[:100]
v_all_gauss = tanh_fit['Gaussian tanh']['posterior'].stan_variable("W_L")[:100]

W_all_RHS = tanh_fit['Regularized Horseshoe tanh']['posterior'].stan_variable("W_1")[:100]
v_all_RHS = tanh_fit['Regularized Horseshoe tanh']['posterior'].stan_variable("W_L")[:100]

W_all_DHS = tanh_fit['Dirichlet Horseshoe tanh']['posterior'].stan_variable("W_1")[:100]
v_all_DHS = tanh_fit['Dirichlet Horseshoe tanh']['posterior'].stan_variable("W_L")[:100]

W_all_DST = tanh_fit['Dirichlet Student T tanh']['posterior'].stan_variable("W_1")[:100]
v_all_DST = tanh_fit['Dirichlet Student T tanh']['posterior'].stan_variable("W_L")[:100]

In [None]:
# --- Gaussian: align the draws, print the summary metrics, plot a representative draw ---
W_fix_g, v_fix_g, wbar_fix_g, summary_g = align_and_compare(W_all_gauss, v_all_gauss, w_bar_gauss, sort_key="abs_v")
print("Gaussian summary:", summary_g)
plot_map_vs_barw(W_fix_g, wbar_fix_g, title="Gaussian prior: MAP-like $w$ vs linearized $\\bar{w}$")


In [None]:
# --- Regularized Horseshoe: align the draws, print the summary metrics, plot a representative draw ---
W_fix_r, v_fix_r, wbar_fix_r, summary_r = align_and_compare(W_all_RHS, v_all_RHS, w_bar_RHS, sort_key="abs_v")
print("RHS summary:", summary_r)
plot_map_vs_barw(W_fix_r, wbar_fix_r, title="RHS prior: MAP-like $w$ vs linearized $\\bar{w}$")


In [None]:
# --- Dirichlet Horseshoe: align the draws, print the summary metrics, plot a representative draw ---
W_fix_dhs, v_fix_dhs, wbar_fix_dhs, summary_dhs = align_and_compare(W_all_DHS, v_all_DHS, w_bar_DHS, sort_key="abs_v")
print("DHS summary:", summary_dhs)
plot_map_vs_barw(W_fix_dhs, wbar_fix_dhs, title="DHS prior: MAP-like $w$ vs linearized $\\bar{w}$")


In [None]:

# --- Dirichlet Student-t: align the draws, print the summary metrics, plot a representative draw ---
W_fix_dst, v_fix_dst, wbar_fix_dst, summary_dst = align_and_compare(W_all_DST, v_all_DST, w_bar_DST, sort_key="abs_v")
print("DST summary:", summary_dst)
plot_map_vs_barw(W_fix_dst, wbar_fix_dst, title="DST prior: MAP-like $w$ vs linearized $\\bar{w}$")
