In [None]:
import sys, os; sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__) if '__file__' in globals() else os.getcwd(), '..')))
#import os; os.chdir(os.path.dirname(os.getcwd()))
from utils.model_loader import get_model_fits
import numpy as np
import pandas as pd
import re
#from sklearn.metrics import mean_squared_errosr
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
results_dir_priors = "results/priors/single_layer/tanh/friedman"
results_dir_posteriors = "results/regression/single_layer/tanh/friedman"

prior_names = ["Dirichlet Horseshoe", "Regularized Horseshoe", "Dirichlet Student T", "Gaussian"]
posterior_names = ["Dirichlet Horseshoe tanh", "Regularized Horseshoe tanh", "Dirichlet Student T tanh", "Gaussian tanh"]

# Friedman dataset configuration per training-set size N (each N uses its own seed).
friedman_configs = {
    100: "Friedman_N100_p10_sigma1.00_seed1",
    200: "Friedman_N200_p10_sigma1.00_seed2",
    500: "Friedman_N500_p10_sigma1.00_seed11",
}


def load_fits_by_N(results_dir, models):
    """Load the fits of `models` from `results_dir` for every entry of `friedman_configs`.

    Returns a dict {N: fits}, loaded in the key order of `friedman_configs`.
    """
    return {
        N: get_model_fits(
            config=config,
            results_dir=results_dir,
            models=models,
            include_prior=False,
        )
        for N, config in friedman_configs.items()
    }


prior_fits = load_fits_by_N(results_dir_priors, prior_names)
posterior_fits = load_fits_by_N(results_dir_posteriors, posterior_names)

prior_N100_fits, prior_N200_fits, prior_N500_fits = (prior_fits[N] for N in (100, 200, 500))
posterior_N100_fits, posterior_N200_fits, posterior_N500_fits = (
    posterior_fits[N] for N in (100, 200, 500)
)


In [3]:
# NOTE(review): this loads the N500 dataset, but the shrinkage analysis further down runs on
# posterior_N100_fits (config Friedman_N100_p10_sigma1.00_seed1). S = J_W^T Σ_y^{-1} J_W is then
# evaluated on inputs the N=100 posterior was not fit to — confirm this is intended.
path = "datasets/friedman/Friedman_N500_p10_sigma1.00_seed11.npz"
# np.load on an .npz returns an NpzFile that keeps the file handle open; the context manager
# closes it once both arrays have been read into memory.
with np.load(path) as data:
    X = data['X_train']
    y = data['y_train']

In [45]:
import numpy as np
from typing import Tuple, Callable

# ---------- Aktivasjon og deriverte ----------

def get_activation(activation: str = "tanh") -> Tuple[Callable, Callable]:
    """Look up an activation function together with its elementwise derivative.

    Args:
        activation: name of the activation, either ``"tanh"`` or ``"relu"``.

    Returns:
        ``(phi, dphi)``: ``phi(a)`` applies the activation, ``dphi(a)`` evaluates its
        derivative at ``a``.

    Raises:
        ValueError: if ``activation`` is not a supported name.
    """
    def _tanh_prime(a):
        # d/da tanh(a) = 1 - tanh(a)^2
        return 1.0 - np.tanh(a)**2

    def _relu(a):
        return np.maximum(0.0, a)

    def _relu_prime(a):
        # Step function (0 at a == 0), cast back to the dtype of the input array
        return (a > 0.0).astype(a.dtype)

    if activation == "tanh":
        return np.tanh, _tanh_prime
    if activation == "relu":
        return _relu, _relu_prime
    raise ValueError(f"Unsupported activation: {activation}")

# ---------- H(w0) og J_W(w0, v0) ----------

def build_hidden_and_jacobian_W(
    X: np.ndarray,               # (n, p)
    W0: np.ndarray,              # (H, p)  -- hidden weights at the reference point w0
    b0: np.ndarray,              # (H,)    -- hidden biases at the reference point w0
    v0: np.ndarray,              # (H,)    -- output weights at the reference point v0
    activation: str = "tanh",
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Hidden activations and network-output Jacobians at the reference point (w0, v0).

    With pre-activations a_{i,h} = x_i . W0[h] + b0[h]:

    Returns:
      Phi_mat : (n, H)    = H(w0), the post-activations phi(a_{i,h})
      JW      : (n, H*p)  = d(H(w)v)/d vec(W) at (w0, v0); entry [i, h*p + j] is
                            v0[h] * phi'(a_{i,h}) * x_{i,j} (columns ordered h-major, then j)
      Jb      : (n, H)    = derivative w.r.t. the hidden biases: v0[h] * phi'(a_{i,h})
      Joutb   : (n,)      = derivative w.r.t. the output bias: all ones
    """
    n, p = X.shape
    H, pW = W0.shape
    assert pW == p, f"W0 has {pW} input columns but X has {p}"
    phi, dphi = get_activation(activation)

    A = X @ W0.T + b0[None, :]        # (n, H) pre-activations a_{i,h}
    Phi_mat = phi(A)                  # (n, H) post-activations
    dphiA = dphi(A)                   # (n, H) phi'(a_{i,h})

    # d f / d b_h = v_h * phi'(a_{i,h})
    Jb = dphiA * v0[None, :]          # (n, H)
    # d f / d W_{h,j} = v_h * phi'(a_{i,h}) * x_{i,j}: broadcast to (n, H, p) and flatten the
    # last two axes, so column h*p + j matches stacking one (n, p) block per hidden unit.
    JW = (Jb[:, :, None] * X[:, None, :]).reshape(n, H * p)
    Joutb = np.ones(n)                # (n,)
    return Phi_mat, JW, Jb, Joutb

# ---------- Sigma_y og P ----------

# def build_Sigma_y(
#     Phi_mat: np.ndarray,     # (n, H) = H(w0)
#     tau_v: float,         # prior std for v
#     J_b1: np.ndarray,      # Jacobian for bias
#     J_b2: np.ndarray,
#     noise: float          # støy std i likelihood
# ) -> np.ndarray:
#     """
#     Σ_y = J_bJ_b^T + τ_v^2 H H^T + σ^2 I_n
#     """
#     n = Phi_mat.shape[0]
#     Sigma_y = (tau_v**2) * (Phi_mat @ Phi_mat.T) + (J_b1@J_b1.T) + np.outer(J_b2, J_b2) + (noise**2) * np.eye(n)
#     return Sigma_y #(tau_v**2) * (Phi_mat @ Phi_mat.T) + (J_b1@J_b1.T) + np.outer(J_b2, J_b2) + (noise**2) * np.eye(n)

def build_Sigma_y(
    Phi_mat: np.ndarray,   # (n, H) = H(w0)
    tau_v: float,          # prior std for v
    noise: float,          # likelihood std
    J_b1: np.ndarray = None,     # (n, H), optional
    J_b2: np.ndarray = None,     # (n,),   optional
    include_b1: bool = True,
    include_b2: bool = True,
) -> np.ndarray:
    """
    Build the (n, n) covariance

        Σ_y = τ_v^2 ΦΦ^T + [J_b1 J_b1^T if include_b1] + [J_b2 J_b2^T if include_b2] + σ^2 I_n

    An optional bias term is added only when its flag is True AND its Jacobian is supplied.
    """
    n_obs = Phi_mat.shape[0]
    cov = (tau_v**2) * (Phi_mat @ Phi_mat.T) + (noise**2) * np.eye(n_obs)

    optional_terms = []
    if include_b1 and J_b1 is not None:
        optional_terms.append(J_b1 @ J_b1.T)
    if include_b2 and J_b2 is not None:
        optional_terms.append(np.outer(J_b2, J_b2))

    # Added in the fixed order b1 then b2
    for term in optional_terms:
        cov = cov + term
    return cov


def build_P_from_lambda_tau(
    lambda_tilde: np.ndarray,  # (H, p) local prior factors for W
    tau_w: float               # global scale for w
) -> np.ndarray:
    """
    Diagonal prior precision for vec(W):  P = τ_w^{-2} Λ^{-1}  with  Λ = diag(vec(lambda_tilde)).

    `lambda_tilde` enters as a VARIANCE factor and is not squared here:
        diag(P) = 1 / (τ_w^2 * λ),   i.e. prior variance τ_w^2 * λ per weight.
    NOTE(review): if the Stan model uses lambda_tilde as a standard-deviation factor
    (w ~ N(0, (τ_w λ)^2)), it must be squared before calling this — confirm against the Stan code.

    Entries follow lambda_tilde.reshape(-1), i.e. (h, j) row-major order, which is the same
    column order as J_W in build_hidden_and_jacobian_W.

    Returns:
      P : (H*p, H*p) diagonal matrix.
    """
    lam_vec = lambda_tilde.reshape(-1)          # (H*p,)
    diagP = 1.0 / ((tau_w**2) * lam_vec)        # (H*p,)
    return np.diag(diagP)

# ---------- S, shrinkage-matrise R = (P+S)^{-1} P ----------

def build_S(JW: np.ndarray, Sigma_y: np.ndarray) -> np.ndarray:
    """
    Data-information matrix for vec(W):  S = J_W^T Σ_y^{-1} J_W,  shape (Hp, Hp).

    Σ_y^{-1} J_W is obtained with a linear solve instead of forming the explicit inverse.
    """
    sigma_inv_JW = np.linalg.solve(Sigma_y, JW)   # (n, Hp) = Σ_y^{-1} J_W
    return np.matmul(JW.T, sigma_inv_JW)          # (Hp, Hp)

def shrinkage_matrix(P: np.ndarray, S: np.ndarray) -> np.ndarray:
    """
    Shrinkage matrix R = (P + S)^{-1} P, obtained by solving (P + S) R = P.

    First tries a Cholesky factorisation P + S = L L^T and two solves (L Y = P, then L^T R = Y).
    If anything in that path raises LinAlgError (e.g. P + S is not numerically positive
    definite), falls back to a general dense solve.
    """
    system = P + S
    try:
        chol = np.linalg.cholesky(system)
        return np.linalg.solve(chol.T, np.linalg.solve(chol, P))
    except np.linalg.LinAlgError:
        return np.linalg.solve(system, P)

def shrinkage_matrix_stable(P, S, jitter=0.0):
    """
    Compute R = (P + S)^{-1} P for a diagonal positive P through the whitened form

        M = P^{-1/2} S P^{-1/2},
        R = P^{-1/2} (I + M)^{-1} P^{1/2}.

    Derivation: P + S = P^{1/2} (I + M) P^{1/2}, hence (P + S)^{-1} P = P^{-1/2} (I + M)^{-1} P^{1/2}.

    R is NOT symmetric in general (only when P is a multiple of the identity), so it is returned
    as computed; symmetrizing it would change every off-diagonal entry and its spectrum. R is
    similar to the symmetric (I + M)^{-1}, so its eigenvalues are 1 / (1 + eig(M)).

    Args:
      P      : (N, N) prior precision; only its diagonal is read. The diagonal is clipped into
               [1e-12, float max] to guard against zeros and infinities.
      S      : (N, N) symmetric PSD data-information matrix.
      jitter : optional ridge added to M before the Cholesky factorisation.

    Returns:
      R : (N, N) array.

    Raises:
      numpy.linalg.LinAlgError if I + M (+ jitter) is not numerically positive definite.
    """
    d = np.clip(np.diag(P).astype(float), 1e-12, np.finfo(float).max)
    sqrt_d = np.sqrt(d)
    inv_sqrt_d = 1.0 / sqrt_d

    # P^{-1/2} S P^{-1/2} by row/column scaling (P is diagonal)
    M = inv_sqrt_d[:, None] * S * inv_sqrt_d[None, :]
    if jitter > 0:
        M = M + jitter * np.eye(M.shape[0])

    # (I + M) is SPD -> Cholesky, then (I + M)^{-1} P^{1/2} = L^{-T} L^{-1} P^{1/2}
    L = np.linalg.cholesky(np.eye(M.shape[0]) + M)
    Z = np.linalg.solve(L, np.diag(sqrt_d))
    W = np.linalg.solve(L.T, Z)

    # Left-multiply by P^{-1/2}: scale row i by 1/sqrt(d_i)
    return inv_sqrt_d[:, None] * W

def shrinkage_eigs_and_df(P, S):
    """
    Eigenvalues of R = (P + S)^{-1} P and the effective degrees of freedom, computed in
    P-whitened coordinates.

    With M = P^{-1/2} S P^{-1/2} (P diagonal; its diagonal is floored at 1e-12), R is similar to
    (I + M)^{-1}, so its eigenvalues are r_k = 1 / (1 + mu_k) for the eigenvalues mu_k of M.

    Returns:
      r      : (N,) eigenvalues of R, in descending order (eigvalsh returns mu ascending),
               each in (0, 1].
      df_eff : sum_k mu_k / (1 + mu_k) = tr(I - R) >= 0.
    """
    d = np.diag(P).astype(float)
    inv_sqrt_d = 1.0 / np.sqrt(np.maximum(d, 1e-12))

    M = inv_sqrt_d[:, None] * S * inv_sqrt_d[None, :]
    # eigvalsh reads only one triangle; symmetrize so round-off asymmetry in S is averaged out
    M = 0.5 * (M + M.T)
    # M is PSD in exact arithmetic; clip round-off negatives so r stays in (0, 1] and df_eff >= 0
    mu = np.clip(np.linalg.eigvalsh(M), 0.0, None)
    r = 1.0 / (1.0 + mu)
    df_eff = np.sum(1.0 - r)
    return r, df_eff

# def extract_model_draws(fit_dict, model: str):
#     """
#     Returnerer ALL draws, med 'lambda_all' definert som EFFEKTIV VARIANSFAKTOR per vekt:
#       Gaussian:                 lambda_all = 1
#       Regularized Horseshoe:    lambda_all = lambda_tilde
#       Dirichlet (DHS/DST):      lambda_all = lambda_tilde_data * phi_data

#     Shapes:
#       W_all      : (D, H, p)
#       b_all      : (D, H)
#       v_all      : (D, H)
#       c_all      : (D,)
#       sigma_all  : (D,)
#       tau_w_all  : (D,)
#       tau_v_all  : (D,)   (ones if not in fit)
#       lambda_all : (D, H, p)  <-- effektiv variansfaktor
#     """
#     post = fit_dict[model]['posterior']

#     # W_1: (D, p, H) -> (D, H, p)
#     W_1 = np.asarray(post.stan_variable("W_1"))
#     W_all = np.transpose(W_1, (0, 2, 1))
#     D, H, p = W_all.shape

#     # W_L: (D, H, out_nodes=1) -> (D, H)
#     W_L = np.asarray(post.stan_variable("W_L"))
#     v_all = W_L.reshape(D, -1)

#     # hidden_bias: (D, 1, H) -> (D, H)
#     b_1 = np.asarray(post.stan_variable("hidden_bias"))
#     b_all = b_1.reshape(D, -1)

#     # output_bias: (D, 1) -> (D,)
#     b_2 = np.asarray(post.stan_variable("output_bias"))
#     c_all = b_2.reshape(D)

#     # sigma: (D,)
#     sigma_all = np.asarray(post.stan_variable("sigma")).reshape(D)

#     # Modell-flagg
#     is_gauss      = ("Gaussian" in model)
#     is_rhs        = ("Regularized Horseshoe" in model)
#     is_dirichlet  = ("Dirichlet" in model) or ("DST" in model)

#     # tau_w / tau_v
#     if is_gauss:
#         tau_w_all = np.ones(D)
#         tau_v_all = np.ones(D)
#     else:
#         tau_w_all = np.asarray(post.stan_variable("tau")).reshape(D)
#         try:
#             tau_v_all = np.asarray(post.stan_variable("tau_v")).reshape(D)
#         except Exception:
#             tau_v_all = np.ones(D)

#     # Effektiv lokal variansfaktor lambda_all
#     if is_gauss:
#         lambda_all = np.ones((D, H, p))
#     else:
#         lam_name = "lambda_tilde" if is_rhs else "lambda_tilde_data"
#         lam = np.asarray(post.stan_variable(lam_name))
#         # Bring til (D, H, p)
#         if lam.shape[1:] == (H, p):
#             lam_var = lam
#         else:
#             lam_var = np.transpose(lam, (0, 2, 1))

#         if is_dirichlet:
#             # phi_data: sannsynlige shapes (D, H, p) eller (D, p, H)
#             phi = np.asarray(post.stan_variable("phi_data"))
#             if phi.shape[1:] == (H, p):
#                 phi_hp = phi
#             else:
#                 phi_hp = np.transpose(phi, (0, 2, 1))
#             # Stan: stddev = tau * sqrt(lambda_tilde) * sqrt(phi)
#             #  => var = tau^2 * lambda_tilde * phi
#             lambda_all = lam_var * phi_hp
#         else:
#             # RHS: var = tau^2 * lambda_tilde
#             lambda_all = lam_var

#     return W_all, b_all, v_all, c_all, sigma_all, tau_w_all, tau_v_all, lambda_all

def extract_model_draws(
    fit_dict,
    model: str,
    *,
    lambda_effective_candidates = ("lambda_tilde", "lambda_tilde_data"),
    lambda_raw_candidates       = ("lambda", "lambda_data"),
    include_phi_for_dirichlet: bool = True,
    phi_name: str = "phi_data",
    lambda_kind: str = "effective",
):
    """
    Pull the posterior draws of one model out of `fit_dict`, reshaped for the shrinkage analysis.

    `fit_dict[model]['posterior']` must provide `stan_variable(name)`; any exception raised by
    that call is treated as "variable not present".

    Returns:
      W_all      : (D, H, p)   hidden weights ('W_1', stored as (D, p, H))
      b_all      : (D, H)      hidden biases ('hidden_bias')
      v_all      : (D, H)      output weights ('W_L')
      c_all      : (D,)        output bias ('output_bias')
      sigma_all  : (D,)        'sigma'
      tau_w_all  : (D,)        global weight scale: 'tau', or 'tau_w' if 'tau' is absent
      tau_v_all  : (D,)        'tau_v' (ones if not present)
      lambda_all : (D, H, p)   chosen lambda (effective/raw) per `lambda_kind`
      [lambda_raw_all]         ONLY if lambda_kind == 'both' (then 9 items are returned)

    Conventions:
    - 'effective' lambda = first present of `lambda_effective_candidates` (e.g. `lambda_tilde`,
      `lambda_tilde_data`), multiplied by `phi_name` for Dirichlet-type models (name contains
      "Dirichlet" or "DST") when `include_phi_for_dirichlet=True`. If none is present it falls
      back to the raw lambda (also phi-multiplied for Dirichlet-type models), then to ones.
    - 'raw' lambda = first present of `lambda_raw_candidates` (e.g. `lambda`, `lambda_data`), the
      unregularized local scale. If none is present it falls back to the phi-free effective
      lambda, then to ones.

    Notes:
    - A model is treated as Gaussian if its name contains "Gaussian" or it has neither 'tau'
      nor 'tau_w'; tau_w, tau_v and all lambdas are then ones.
    - Lambda/phi arrays are accepted as (D,H,p) or (D,p,H) and coerced to (D,H,p). When H == p
      the two layouts cannot be told apart and the array is returned unchanged.

    Raises:
      ValueError if a core variable is missing, a lambda/phi array cannot be coerced, or
      `lambda_kind` is not one of 'effective', 'raw', 'both'.
    """
    post = fit_dict[model]['posterior']

    def _stan_var_or_none(name):
        # Best-effort lookup: any failure (e.g. unknown variable name) means "absent".
        try:
            return np.asarray(post.stan_variable(name))
        except Exception:
            return None

    def _first_present(names):
        """First variable among `names` that exists in the posterior, else None."""
        for nm in names:
            arr = _stan_var_or_none(nm)
            if arr is not None:
                return arr
        return None

    def _require(name):
        """Variable `name`, or ValueError if the posterior does not have it."""
        arr = _stan_var_or_none(name)
        if arr is None:
            raise ValueError(f"Missing '{name}' in posterior.")
        return arr

    def _coerce_DHp(arr, D, H, p):
        """Coerce Stan draws to shape (D,H,p). Accepts (D,H,p) or (D,p,H) or (D,p,H,1)/(D,1,H,p)."""
        if arr is None:
            return None
        shp = arr.shape
        if shp == (D, H, p):
            return arr
        if shp == (D, p, H):
            return np.transpose(arr, (0, 2, 1))
        # Common fallbacks (rare):
        if len(shp) == 4 and shp[0] == D:
            # drop singleton dims and retry
            squeezed = np.squeeze(arr)
            return _coerce_DHp(squeezed, D, H, p)
        raise ValueError(f"Cannot coerce lambda/phi array of shape {shp} to (D,{H},{p}).")

    # === Core weights/bias/sigma ===
    W_all = np.transpose(_require("W_1"), (0, 2, 1))   # (D, p, H) -> (D, H, p)
    D, H, p = W_all.shape
    v_all = _require("W_L").reshape(D, -1)              # (D, H, 1) -> (D, H)
    b_all = _require("hidden_bias").reshape(D, -1)      # (D, 1, H) -> (D, H)
    c_all = _require("output_bias").reshape(D)          # (D, 1)    -> (D,)
    sigma_all = _require("sigma").reshape(D)

    # === Model type and global scales ===
    # Both names for the global weight scale are accepted; the model counts as Gaussian only
    # if its name says so or neither name is present.
    tau_w = _first_present(("tau", "tau_w"))
    is_gauss = ("Gaussian" in model) or (tau_w is None)
    is_dirichlet = ("Dirichlet" in model) or ("DST" in model)

    if is_gauss:
        tau_w_all = np.ones(D)
        tau_v_all = np.ones(D)
    else:
        tau_w_all = tau_w.reshape(D)
        tau_v = _stan_var_or_none("tau_v")
        tau_v_all = np.ones(D) if tau_v is None else tau_v.reshape(D)

    # === Lambda extraction ===
    lam_eff = None   # effective (regularized) local factor
    lam_raw = None   # raw (unregularized) local scale
    if not is_gauss:
        lam_eff = _coerce_DHp(_first_present(lambda_effective_candidates), D, H, p)
        lam_raw = _coerce_DHp(_first_present(lambda_raw_candidates), D, H, p)

    # Optional Dirichlet multiplier
    phi_hp = None
    if include_phi_for_dirichlet and is_dirichlet:
        phi_hp = _coerce_DHp(_stan_var_or_none(phi_name), D, H, p)

    ones_DHp = np.ones((D, H, p))

    def _with_phi(lam):
        return lam * (phi_hp if phi_hp is not None else 1.0)

    if is_gauss:
        lambda_eff_all = ones_DHp
        lambda_raw_all = ones_DHp
    else:
        if lam_eff is not None:
            lambda_eff_all = _with_phi(lam_eff) if is_dirichlet else lam_eff
        elif lam_raw is not None:
            lambda_eff_all = _with_phi(lam_raw) if is_dirichlet else lam_raw
        else:
            lambda_eff_all = ones_DHp

        if lam_raw is not None:
            lambda_raw_all = lam_raw
        else:
            lambda_raw_all = lam_eff if lam_eff is not None else ones_DHp

    # === Return ===
    core = (W_all, b_all, v_all, c_all, sigma_all, tau_w_all, tau_v_all)
    if lambda_kind == "effective":
        return (*core, lambda_eff_all)
    if lambda_kind == "raw":
        return (*core, lambda_raw_all)
    if lambda_kind == "both":
        return (*core, lambda_eff_all, lambda_raw_all)
    raise ValueError("lambda_kind must be one of {'effective','raw','both'}.")


# ------- Knyt alt sammen -------

def compute_shrinkage_for_W_block(
    X: np.ndarray,
    W0: np.ndarray, b0: np.ndarray, v0: np.ndarray,
    noise: float, tau_w: float, tau_v: float,
    lambda_tilde: np.ndarray,
    activation: str = "tanh",
    include_b1_in_Sigma: bool = True,
    include_b2_in_Sigma: bool = True,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Shrinkage analysis of the hidden-weight block W at a single posterior draw (W0, b0, v0).

    Builds
      Σ_y = τ_v^2 ΦΦ^T [+ J_b1 J_b1^T] [+ J_b2 J_b2^T] + σ^2 I_n,
      P   = diagonal prior precision of vec(W) from (lambda_tilde, tau_w),
      S   = J_W^T Σ_y^{-1} J_W,
      R   = (P + S)^{-1} P.

    Args:
      X                   : (n, p) inputs.
      W0, b0, v0          : (H, p), (H,), (H,) parameters of the draw.
      noise               : likelihood std σ.
      tau_w, tau_v        : prior scales for W and v.
      lambda_tilde        : (H, p) local prior factors for W.
      activation          : hidden activation name.
      include_b1_in_Sigma : add the hidden-bias term J_b1 J_b1^T to Σ_y.
      include_b2_in_Sigma : add the output-bias term J_b2 J_b2^T to Σ_y.

    Returns:
      (R, P, S, Sigma_y, JW, Phi_mat) with shapes
      (Hp, Hp), (Hp, Hp), (Hp, Hp), (n, n), (n, Hp), (n, H).
    """
    Phi_mat, JW, Jb1, Jb2 = build_hidden_and_jacobian_W(X, W0, b0, v0, activation=activation)
    Sigma_y = build_Sigma_y(
        Phi_mat,
        tau_v=tau_v,
        noise=noise,
        J_b1=Jb1,
        J_b2=Jb2,
        include_b1=include_b1_in_Sigma,
        include_b2=include_b2_in_Sigma,
    )                                                          # (n, n)
    P = build_P_from_lambda_tau(lambda_tilde, tau_w=tau_w)     # (Hp, Hp)
    S = build_S(JW, Sigma_y)                                   # (Hp, Hp)
    R = shrinkage_matrix_stable(P, S)                          # (Hp, Hp)
    return R, P, S, Sigma_y, JW, Phi_mat


def compute_shrinkage(
    X,
    W_all, b_all, v_all,          # (D,H,p), (D,H), (D,H)
    sigma_all, tau_w_all, tau_v_all,  # (D,), (D,), (D,)
    lambda_all,                   # (D,H,p)
    activation="tanh",
    return_mats=True,             # set False if you only want summaries
    include_b1_in_Sigma: bool = True,
    include_b2_in_Sigma: bool = True,
):
    """
    Run compute_shrinkage_for_W_block for every posterior draw and collect the results.

    With N = H * p, returns a 7-tuple:
      R_stack      : (D, N, N) R = (P+S)^{-1} P per draw           (None if return_mats=False)
      S_stack      : (D, N, N) S = J_W^T Σ_y^{-1} J_W per draw     (None if return_mats=False)
      P_stack      : (D, N, N) diagonal prior precision per draw    (None if return_mats=False)
      G_stack      : (D, N, N) G = P^{-1/2} S P^{-1/2} per draw     (None if return_mats=False)
      shrink_stack : (D, N, N) (I + G)^{-1} G per draw              (None if return_mats=False)
      r_eigs       : (D, N)    eigenvalues of R per draw, ascending, in [0, 1]
      df_eff       : (D,)      effective dof = tr(I - R) = N - tr(R)
    """
    D, H, p = W_all.shape
    N = H * p
    eye_N = np.identity(N)   # loop-invariant

    if return_mats:
        R_stack = np.empty((D, N, N))
        S_stack = np.empty((D, N, N))
        P_stack = np.empty((D, N, N))
        G_stack = np.empty((D, N, N))
        shrink_stack = np.empty((D, N, N))
    else:
        R_stack = S_stack = P_stack = G_stack = shrink_stack = None
    r_eigs = np.empty((D, N))
    df_eff = np.empty(D)

    for d in range(D):
        R, P, S, _Sigma_y, _JW, _Phi = compute_shrinkage_for_W_block(
            X=X,
            W0=W_all[d],
            b0=b_all[d],
            v0=v_all[d],
            noise=float(sigma_all[d]),
            tau_w=float(tau_w_all[d]),
            tau_v=float(tau_v_all[d]),
            lambda_tilde=lambda_all[d],
            activation=activation,
            include_b1_in_Sigma=include_b1_in_Sigma,
            include_b2_in_Sigma=include_b2_in_Sigma,
        )

        if return_mats:
            # Whitened data information G = P^{-1/2} S P^{-1/2} (P is diagonal)
            inv_sqrt_prec = 1.0 / np.sqrt(np.diag(P))
            G = inv_sqrt_prec[:, None] * S * inv_sqrt_prec[None, :]
            R_stack[d] = R
            S_stack[d] = S
            P_stack[d] = P
            G_stack[d] = G
            # (I + G)^{-1} G via a linear solve rather than an explicit inverse
            shrink_stack[d] = np.linalg.solve(eye_N + G, G)

        r, df = shrinkage_eigs_and_df(P, S)
        r_eigs[d] = np.sort(r)
        df_eff[d] = df

    return R_stack, S_stack, P_stack, G_stack, shrink_stack, r_eigs, df_eff


# ---------- Minimal run: W-block shrinkage for every posterior model (N = 100 fit) ----------
# All models use the same settings: tanh activation, bias terms left out of Σ_y.
# To look at the raw (unregularized) lambda values instead, pass lambda_kind='raw' to
# extract_model_draws.
_SHRINKAGE_MODELS = [
    ("gauss", "Gaussian tanh"),
    ("RHS", "Regularized Horseshoe tanh"),
    ("DHS", "Dirichlet Horseshoe tanh"),
    ("DST", "Dirichlet Student T tanh"),
]

shrinkage_results = {}
for _key, _model in _SHRINKAGE_MODELS:
    # W, b1, v, ... are module-level; after the loop they hold the last model's draws.
    W, b1, v, b2, noise, tau_w, tau_v, lambda_eff = extract_model_draws(
        posterior_N100_fits, model=_model
    )
    shrinkage_results[_key] = compute_shrinkage(
        X, W, b1, v, noise, tau_w, tau_v, lambda_eff,
        activation="tanh",
        include_b1_in_Sigma=False,
        include_b2_in_Sigma=False,
    )

R_gauss, S_gauss, P_gauss, G_gauss, shrink_gauss, eigs_gauss, df_gauss = shrinkage_results["gauss"]
R_RHS, S_RHS, P_RHS, G_RHS, shrink_RHS, eigs_RHS, df_eff_RHS = shrinkage_results["RHS"]
R_DHS, S_DHS, P_DHS, G_DHS, shrink_DHS, eigs_DHS, df_eff_DHS = shrinkage_results["DHS"]
R_DST, S_DST, P_DST, G_DST, shrink_DST, eigs_DST, df_eff_DST = shrinkage_results["DST"]

In [46]:

import matplotlib.pyplot as plt
from matplotlib.colors import SymLogNorm, Normalize, TwoSlopeNorm
import numpy as np


def add_block_grid(ax, H, p, color="w", lw=0.5):
    """
    Draw separator lines between the H consecutive (p x p) blocks of an (H*p x H*p) heatmap.

    For k = p, 2p, ..., (H-1)p a horizontal and a vertical line are drawn at k - 0.5, i.e. on the
    pixel edges between blocks when the image uses integer pixel centres (imshow default).
    """
    for h in range(1, H):
        edge = h * p - 0.5
        ax.axhline(edge, color=color, lw=lw)
        ax.axvline(edge, color=color, lw=lw)

def visualize_models(
    matrices, names, H=16, p=10, use_abs=False, cmap="magma",
    q_low=0.05, q_high=0.95
):
    """
    Heatmaps of several (H*p x H*p) matrices on one shared, robust colour scale.

    vmin / vmax are the q_low / q_high quantiles of all finite entries across ALL matrices;
    values outside [vmin, vmax] are clipped to the ends. Block separators for the H (p x p)
    blocks are drawn on every panel. Panels are laid out on a near-square grid (2 x 2 for
    four matrices); unused grid cells are hidden.

    Args:
      matrices      : sequence of 2-D arrays.
      names         : panel titles, matched to `matrices` by position.
      H, p          : block structure for the separator grid.
      use_abs       : plot |M| instead of M.
      cmap          : matplotlib colormap name.
      q_low, q_high : quantiles defining the shared colour range.
    """
    # Optionally plot magnitudes
    mats = [np.abs(M) if use_abs else M for M in matrices]

    # Shared colour range from the finite entries of all matrices
    finite_parts = [M[np.isfinite(M)].ravel() for M in mats]
    all_vals = np.concatenate(finite_parts) if finite_parts else np.array([])

    if all_vals.size == 0:
        vmin, vmax = -1.0, 1.0
    else:
        vmin = float(np.quantile(all_vals, q_low))
        vmax = float(np.quantile(all_vals, q_high))
        if not np.isfinite(vmin) or not np.isfinite(vmax) or vmin == vmax:
            # Degenerate (all equal) or overflowing range: unit window around the mean (or 0)
            mean_val = float(np.mean(all_vals))
            m = mean_val if np.isfinite(mean_val) else 0.0
            vmin, vmax = m - 1.0, m + 1.0

    norm = Normalize(vmin=vmin, vmax=vmax)

    # Near-square grid sized to the number of panels
    n_panels = max(len(mats), 1)
    ncols = int(np.ceil(np.sqrt(n_panels)))
    nrows = int(np.ceil(n_panels / ncols))
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(5 * ncols, 5 * nrows), dpi=150,
        constrained_layout=True, squeeze=False,
    )
    axes = axes.ravel()

    im = None
    used_axes = []
    for ax, M, title in zip(axes, mats, names):
        Mplot = np.clip(M, vmin, vmax)  # clip values outside the shared range
        im = ax.imshow(Mplot, aspect='equal', interpolation='nearest', cmap=cmap, norm=norm)
        add_block_grid(ax, H, p)
        ax.set_title(title)
        ax.set_xlabel("Columns")
        ax.set_ylabel("Rows")
        used_axes.append(ax)

    # Hide grid cells that received no panel
    for ax in axes[len(used_axes):]:
        ax.set_visible(False)

    # Shared colorbar over the drawn panels
    if im is not None:
        fig.colorbar(im, ax=used_axes, label=("|Value|" if use_abs else "Value"))

    # Report the colour range that was actually used
    print(f"Felles kvantiler: vmin (q={q_low}) = {vmin:.6g}, vmax (q={q_high}) = {vmax:.6g}")

    plt.show()


In [None]:
# 3) Plot S = J_W^T Σ_y^{-1} J_W per prior, summarized over posterior draws by the
# log-Euclidean median (an elementwise mean, S_*.mean(axis=0), was used earlier).
# NOTE: log_euclidean_median is defined in a later cell of this notebook; that cell must be
# run first (a fresh Restart & Run All raises NameError here otherwise).
matrices_S = [
    log_euclidean_median(S_stack)
    for S_stack in (S_gauss, S_RHS, S_DHS, S_DST)
]
names_S = ["S (Gauss)", "S (RHS)", "S (DHS)", "S (DST)"]

visualize_models(matrices_S, names_S, H=16, p=10, use_abs=False, cmap="magma")


In [None]:
# 1) Plot the whitened information G = P^{-1/2} S P^{-1/2} per prior, summarized over
# posterior draws by the log-Euclidean median (earlier variants: elementwise mean
# G_*.mean(axis=0) and elementwise median np.median(G_*, axis=0)).
# NOTE: log_euclidean_median is defined in a later cell of this notebook; that cell must be
# run first (a fresh Restart & Run All raises NameError here otherwise).
matrices_G = [
    log_euclidean_median(G_draws)
    for G_draws in (G_gauss, G_RHS, G_DHS, G_DST)
]

names_G = ["G (Gauss)", "G (RHS)", "G (DHS)", "G (DST)"]

visualize_models(matrices_G, names_G, H=16, p=10, use_abs=False, cmap="magma")


In [None]:
# 2) Plot (I + G)^{-1} G   (shrinkage in whitened space) per prior, summarized over posterior
# draws by the log-Euclidean median (earlier variants: elementwise mean
# shrink_*.mean(axis=0) and elementwise median np.median(shrink_*, axis=0)).
# NOTE: log_euclidean_median is defined in a later cell of this notebook; that cell must be
# run first (a fresh Restart & Run All raises NameError here otherwise).
matrices_shrink = [
    log_euclidean_median(shrink_draws)
    for shrink_draws in (shrink_gauss, shrink_RHS, shrink_DHS, shrink_DST)
]

names_shrink = ["(I+G)^{-1}G (Gauss)", "(I+G)^{-1}G (RHS)", "(I+G)^{-1}G (DHS)", "(I+G)^{-1}G (DST)"]

visualize_models(matrices_shrink, names_shrink, H=16, p=10, use_abs=False, cmap="magma")


In [None]:

# (P+S)^{-1} S = I - R per posterior draw, where R = (P+S)^{-1} P is the W-block shrinkage matrix.
# NOTE(review): the exact (P+S)^{-1} S is not symmetric unless P is a multiple of I, while
# log_euclidean_median treats its input as SPD (it symmetrizes first), so these panels show a
# median of the symmetric part only — interpret with care. log_euclidean_median is defined in a
# later cell and must be run before this one.
N_W = R_gauss.shape[-1]   # = H * p (16 * 10 for this network)
I_W = np.eye(N_W)
SP_inv_S_gauss = I_W - R_gauss
SP_inv_S_RHS = I_W - R_RHS
SP_inv_S_DHS = I_W - R_DHS
SP_inv_S_DST = I_W - R_DST
matrices_SP_inv_S = [
    log_euclidean_median(stack)
    for stack in (SP_inv_S_gauss, SP_inv_S_RHS, SP_inv_S_DHS, SP_inv_S_DST)
]
names_SP_inv_S = ["(P+S)^{-1}S (Gauss)", "(P+S)^{-1}S (RHS)", "(P+S)^{-1}S (DHS)", "(P+S)^{-1}S (DST)"]

visualize_models(matrices_SP_inv_S, names_SP_inv_S, H=16, p=10, use_abs=False, cmap="magma")


In [11]:
# --- Helpers: log–Euclidean median for SPD stacks (D, N, N) ---
import numpy as np

def _symmetrize(M):
    """Symmetric part (M + M^T) / 2, taken over the last two axes."""
    return 0.5*(M + M.swapaxes(-1, -2))

def _spd_log(M, eps=1e-8):
    """
    Matrix logarithm of a symmetric PSD matrix via its eigendecomposition.

    The input is symmetrized and regularized with eps * I, and the eigenvalues are floored at
    eps before the log, so singular or slightly indefinite inputs still give a finite result.
    """
    M = _symmetrize(M)
    w, U = np.linalg.eigh(M + eps*np.eye(M.shape[-1]))
    # U diag(log w) U^T, scaling columns of U instead of building diag()
    return (U * np.log(np.clip(w, eps, None))) @ U.T

def _spd_exp(M):
    """Matrix exponential of a symmetric matrix via its eigendecomposition (input is symmetrized)."""
    M = _symmetrize(M)
    w, U = np.linalg.eigh(M)
    return (U * np.exp(w)) @ U.T

def log_euclidean_median(stack, eps=1e-8):
    """
    Log-Euclidean "median" of a stack of SPD/PSD matrices.

    Every matrix is mapped to the log domain with `_spd_log`, the entrywise median over draws
    is taken there, and the result is mapped back with `_spd_exp`. This is the entrywise median
    of the matrix logarithms, not a geometric median.

    Args:
      stack : (D, N, N) with D >= 1. Non-symmetric inputs are symmetrized first, so only their
              symmetric part contributes.
      eps   : regularization passed to `_spd_log`.

    Returns:
      (N, N) symmetric positive-definite matrix.

    Raises:
      ValueError if the stack holds no matrices.
    """
    stack = np.asarray(stack)
    D, N, _ = stack.shape
    if D == 0:
        raise ValueError("log_euclidean_median needs at least one matrix (got an empty stack).")
    # Float buffer: np.empty_like on an integer stack would silently truncate the logs
    logs = np.empty((D, N, N), dtype=float)
    for d in range(D):
        logs[d] = _spd_log(stack[d], eps=eps)
    med_log = np.median(logs, axis=0)
    return _spd_exp(med_log)


In [None]:
# --- Log–Euclidean medians for your shrink stacks + visualize ---
# Repeats the "2) Plot (I + G)^{-1} G" cell, now that log_euclidean_median is defined above.
matrices_shrink = [
    log_euclidean_median(draws)
    for draws in (shrink_gauss, shrink_RHS, shrink_DHS, shrink_DST)
]
names_shrink = ["(I+G)^{-1}G (Gauss)", "(I+G)^{-1}G (RHS)", "(I+G)^{-1}G (DHS)", "(I+G)^{-1}G (DST)"]

visualize_models(matrices_shrink, names_shrink, H=16, p=10, use_abs=False, cmap="magma")


In [13]:
# I - R = (P+S)^{-1} S per posterior draw; the size comes from R rather than a hard-coded 16*10.
N_W = R_gauss.shape[-1]   # = H * p
I_W = np.eye(N_W)
SP_inv_S_gauss = I_W - R_gauss
SP_inv_S_RHS = I_W - R_RHS
SP_inv_S_DHS = I_W - R_DHS
SP_inv_S_DST = I_W - R_DST

In [14]:
# --- Per-draw traces, viewed as distributions over posterior draws ---
# Both families below measure the effective degrees of freedom of the W block:
#   tr((I+G)^{-1} G) = sum_k mu_k / (1 + mu_k),  mu_k = eigenvalues of G = P^{-1/2} S P^{-1/2}
#   tr(I - R)        = tr((P+S)^{-1} S),          R = (P+S)^{-1} P
# Since (P+S)^{-1} S = P^{-1/2} (I+G)^{-1} G P^{1/2} is similar to (I+G)^{-1} G, the two agree up
# to round-off. (Despite the name, tr_R_* holds the trace of (I+G)^{-1} G, not of R.)
import matplotlib.pyplot as plt


def _batch_trace(stack):
    """Trace of every (N, N) matrix in a (D, N, N) stack -> array of shape (D,)."""
    return np.trace(stack, axis1=1, axis2=2)


tr_R_gauss = _batch_trace(shrink_gauss)
tr_R_RHS   = _batch_trace(shrink_RHS)
tr_R_DHS   = _batch_trace(shrink_DHS)
tr_R_DST   = _batch_trace(shrink_DST)

tr_SPinvS_gauss = _batch_trace(SP_inv_S_gauss)
tr_SPinvS_RHS   = _batch_trace(SP_inv_S_RHS)
tr_SPinvS_DHS   = _batch_trace(SP_inv_S_DHS)
tr_SPinvS_DST   = _batch_trace(SP_inv_S_DST)



In [None]:
# Plot df_eff distributions. Both figures show the effective dof of W computed two ways
# (tr((I+G)^{-1}G) and tr((P+S)^{-1}S) = tr(I-R)); they should overlap up to round-off.
def _plot_trace_histograms(traces_by_prior, xlabel, bins):
    """Overlaid histograms of per-draw traces; `traces_by_prior` maps legend label -> (D,) array."""
    plt.figure(figsize=(8,4), dpi=150)
    for label, values in traces_by_prior.items():
        plt.hist(values, bins=bins, alpha=0.5, label=label)
    plt.xlabel(xlabel)
    plt.ylabel("count")
    plt.legend()
    plt.tight_layout()
    plt.show()


bins = 40
_plot_trace_histograms(
    {"Gauss": tr_R_gauss, "RHS": tr_R_RHS, "DHS": tr_R_DHS, "DST": tr_R_DST},
    "trace((I+G)^{-1}G)  [effective dof]",
    bins,
)
_plot_trace_histograms(
    {"Gauss": tr_SPinvS_gauss, "RHS": tr_SPinvS_RHS, "DHS": tr_SPinvS_DHS, "DST": tr_SPinvS_DST},
    "trace((P+S)^{-1}S)  [dof]",
    bins,
)

In [None]:
# --- Median eigenvalue curve (with bands) for shrink stacks ---
import matplotlib.pyplot as plt

def median_eigcurve(stack, q_lo=0.1, q_hi=0.9):
    """
    Rank-wise median and quantile band of the eigenvalue spectra in a stack of matrices.

    Args:
      stack      : (D, N, N) symmetric matrices (eigvalsh reads only one triangle of each).
      q_lo, q_hi : quantiles for the lower / upper band.

    Returns:
      dict with keys 'median', 'lo', 'hi', each (N,), computed over the D draws after sorting
      every draw's eigenvalues in descending order.
    """
    n_draws, size, _ = stack.shape
    spectra = np.empty((n_draws, size))
    for i, mat in enumerate(stack):
        spectra[i] = np.sort(np.linalg.eigvalsh(mat))[::-1]   # descending
    return {
        "median": np.median(spectra, axis=0),
        "lo": np.quantile(spectra, q_lo, axis=0),
        "hi": np.quantile(spectra, q_hi, axis=0),
    }

# Spectrum of (I+G)^{-1}G per prior: median over draws with a 10-90% quantile band
curves = {
    name: median_eigcurve(stack)
    for name, stack in (
        ("Gauss", shrink_gauss),
        ("RHS", shrink_RHS),
        ("DHS", shrink_DHS),
        ("DST", shrink_DST),
    )
}

# Plot 2x2 small multiples
fig, axes = plt.subplots(2, 2, figsize=(8,6), dpi=150, constrained_layout=True)
axes = axes.ravel()
for ax, (name, c) in zip(axes, curves.items()):
    x = np.arange(1, len(c["median"]) + 1)   # eigenvalue rank, 1-based
    ax.plot(x, c["median"], lw=1.8, label=f"{name} median")
    ax.fill_between(x, c["lo"], c["hi"], alpha=0.25, label=f"{name} {10}-{90}%", step=None)
    ax.set_title(name)
    ax.set_xlabel("eigenvalue rank")
    ax.set_ylabel("eigenvalue of (I+G)^{-1}G")
    ax.set_ylim(0, 1)
    ax.legend(loc="upper right", fontsize=8)
plt.show()


In [None]:
# --- Helper: blockwise normalized inner-product ratio for a stack of G ---
import numpy as np

def block_ratio_median(G_stack: np.ndarray, H: int, p: int, eps: float = 1e-12):
    """
    Elementwise median over draws of the normalised magnitude
        R_d[i, j] = |G_d[i, j]| / sqrt(G_d[i, i] * G_d[j, j]),
    restricted to entries (i, j) that lie in the same (p x p) hidden-unit block.

    Parameters
    ----------
    G_stack : (D, N, N) array with N = H * p. Each G_d is symmetrised as 0.5 (G_d + G_d^T).
    H, p    : number of blocks and block size.
    eps     : floor applied to the diagonal before taking square roots.

    Returns
    -------
    (N, N) array: within-block median ratio, 0 on cross-block entries, exactly 1 on the diagonal.

    Raises
    ------
    ValueError : if G_stack is not a stack of square matrices or N != H * p.
    """
    G_stack = np.asarray(G_stack, dtype=float)
    if G_stack.ndim != 3 or G_stack.shape[1] != G_stack.shape[2]:
        raise ValueError(f"G_stack must have shape (D, N, N); got {G_stack.shape}.")
    D, N, _ = G_stack.shape
    if N != H * p:
        raise ValueError(f"G_stack has N={N}, but H*p={H * p}.")

    # block mask: True if (i, j) belong to the same hidden-unit block
    h_idx = np.arange(N) // p
    same_block = h_idx[:, None] == h_idx[None, :]

    G = 0.5 * (G_stack + np.swapaxes(G_stack, 1, 2))                  # symmetrize every draw
    dvec = np.clip(np.diagonal(G, axis1=1, axis2=2), eps, None)        # (D, N)
    denom = np.sqrt(dvec[:, :, None] * dvec[:, None, :])               # (D, N, N)
    R = np.divide(np.abs(G), denom, out=np.zeros_like(G), where=(denom > 0))

    # Median only over within-block entries. Cross-block entries never enter
    # nanmedian, so no "All-NaN slice" warnings are raised. nanmedian still
    # ignores NaNs coming from G itself.
    R_med = np.zeros((N, N))
    R_med[same_block] = np.nanmedian(R[:, same_block], axis=0)
    # exact ones on the diagonal by definition
    np.fill_diagonal(R_med, 1.0)
    return R_med

# --- Median within-block ratio matrices, one per prior ---
H, p = 16, 10  # H blocks of size p (N = H*p); adjust if needed
ratio_Gauss, ratio_RHS, ratio_DHS, ratio_DST = [
    block_ratio_median(G_stack, H=H, p=p)
    for G_stack in (G_gauss, G_RHS, G_DHS, G_DST)
]


# --- (Optional) Scalar summaries of the within-block off-diagonal ratios ---
def offdiag_block_stats(R_med: np.ndarray, H: int, p: int):
    """
    Summarise the within-block, off-diagonal entries of an (H*p, H*p) ratio matrix.

    Returns a dict with the median and the 10% / 90% quantiles of those entries (as floats).
    """
    size = H * p
    unit_of = np.repeat(np.arange(H), p)  # block index of each coordinate
    mask = (unit_of[:, None] == unit_of[None, :]) & ~np.eye(size, dtype=bool)
    entries = R_med[mask]
    return {
        "median_offdiag": float(np.median(entries)),
        "q10_offdiag": float(np.quantile(entries, 0.10)),
        "q90_offdiag": float(np.quantile(entries, 0.90)),
    }

# Print the within-block off-diagonal summaries for every prior.
for label, ratio in (
    ("Gauss:", ratio_Gauss),
    ("RHS:  ", ratio_RHS),
    ("DHS:  ", ratio_DHS),
    ("DST:  ", ratio_DST),
):
    print(label, offdiag_block_stats(ratio, H, p))


In [None]:
# Posterior draws of the Stan variable `phi_data` from the Dirichlet Horseshoe (N=100) fit.
phi_samples = posterior_N100_fits['Dirichlet Horseshoe tanh']['posterior'].stan_variable("phi_data")
# Display one 1-D slice (last expression of the cell). The [2, 2, :] index assumes
# phi_samples has at least 3 dimensions (draw, ...) — TODO confirm the layout.
phi_samples[2, 2, :]

## Bounds

In [80]:
# --- (1) Bound Σ_y^{-1}: compute c_min, c_max for a single draw ---
import numpy as np

def sigma_inverse_bounds(Phi_mat, tau_v, noise, J_b1=None, J_b2=None,
                         include_b1=True, include_b2=True):
    """
    Spectral bounds on Σ_y^{-1} for Σ_y = σ² I + Q Qᵀ.

    Q stacks the columns [J_b1, τ_v Φ, J_b2]. The bias blocks are included only
    when the corresponding flag is True and the array is not None. Since
    σ² I ⪯ Σ_y ⪯ (σ² + ||Q||₂²) I, it follows that
        c_min = 1 / (σ² + ||Q||₂²)   ⪯   Σ_y^{-1}   ⪯   c_max = 1 / σ².

    Parameters
    ----------
    Phi_mat : (n, k) array (a 1-D array is treated as a single column).
    tau_v   : scale multiplying Φ.
    noise   : σ (standard deviation, squared inside).
    J_b1, J_b2 : (n,) or (n, m) arrays or None. A 1-D array is treated as a single column.
    include_b1, include_b2 : whether to put J_b1 / J_b2 into Q.

    Returns
    -------
    c_min, c_max, smax   where smax = ||Q||₂ (largest singular value of Q).
    """
    def _as_columns(a):
        # A 1-D Jacobian is one column; 2-D arrays are used as given.
        a = np.asarray(a, dtype=float)
        return a[:, None] if a.ndim == 1 else a

    cols = [tau_v * _as_columns(Phi_mat)]  # Φ always contributes, scaled by τ_v
    if include_b1 and (J_b1 is not None):
        cols.insert(0, _as_columns(J_b1))  # [J_b1, τ_v Φ, ...]
    if include_b2 and (J_b2 is not None):
        cols.append(_as_columns(J_b2))     # accepts (n,) or (n, 1)

    Q = np.concatenate(cols, axis=1) if len(cols) > 1 else cols[0]
    # spectral norm of Q (largest singular value)
    smax = np.linalg.svd(Q, compute_uv=False)[0]
    c_min = 1.0 / (noise**2 + smax**2)
    c_max = 1.0 / (noise**2)
    return c_min, c_max, smax

# --- (2) Bound S and G via c_min, c_max; also get α-eigs of A = P^{-1/2} J^T J P^{-1/2} ---
def bound_S_and_G(JW, P, c_min, c_max):
    """
    Loewner bounds on S = J^T Σ_y^{-1} J and on G = P^{-1/2} S P^{-1/2}.

    From c_min*I ⪯ Σ_y^{-1} ⪯ c_max*I:
        c_min * J^T J ⪯ S ⪯ c_max * J^T J,
        c_min * A     ⪯ G ⪯ c_max * A,     A = P^{-1/2} J^T J P^{-1/2}.

    Only the diagonal of P is used (P is treated as diagonal); its entries are
    floored at 1e-12 before taking P^{-1/2}.

    Returns
    -------
    JJ      : J^T J
    A       : P^{-1/2} (J^T J) P^{-1/2}
    alpha   : eigvals(A), ascending
    S_lo/up : lower/upper Loewner bounds on S
    G_lo/up : lower/upper Loewner bounds on G
    """
    JJ = JW.T @ JW
    d = np.diag(P).astype(float)
    inv_sqrt_d = 1.0 / np.sqrt(np.maximum(d, 1e-12))
    # P^{-1/2} JJ P^{-1/2} for diagonal P, done as an O(N^2) row/column scaling
    # instead of two dense N x N matrix products with a diagonal matrix.
    A = JJ * inv_sqrt_d[:, None] * inv_sqrt_d[None, :]
    alpha = np.linalg.eigvalsh(A)  # ascending; >= 0 up to round-off

    S_lower = c_min * JJ
    S_upper = c_max * JJ
    G_lower = c_min * A
    G_upper = c_max * A
    return JJ, A, alpha, S_lower, S_upper, G_lower, G_upper

# --- (3) Eigenvalue bands for (I+G)^{-1} G and trace/df bounds from α-eigs and c_min/c_max ---
def shrinkage_bands_from_alpha(alpha, c_min, c_max):
    """
    Eigenvalue band of the shrinkage operator (I+G)^{-1}G.

    With f(t) = t / (1 + t) and c_min*A ⪯ G ⪯ c_max*A, the i-th eigenvalue obeys
        f(c_min * alpha_i) <= λ_i <= f(c_max * alpha_i).
    Both returned arrays follow the (ascending) order of `alpha`.
    """
    t_lower = c_min * alpha
    t_upper = c_max * alpha
    return t_lower / (1.0 + t_lower), t_upper / (1.0 + t_upper)

def df_bounds_from_alpha(alpha, c_min, c_max):
    """Lower/upper bounds on df_eff = tr((I+G)^{-1}G): sums of f(c*alpha_i), f(t) = t/(1+t), for c = c_min, c_max."""
    bounds = []
    for c in (c_min, c_max):
        t = c * alpha
        bounds.append(float(np.sum(t / (1.0 + t))))
    return tuple(bounds)

# --- (4) Bands over ALL draws (median band to overlay with your median eigenvalue curve) ---
def eigen_bands_over_draws(
    X, W_all, b_all, v_all, sigma_all, tau_w_all, tau_v_all, lambda_all,
    activation="tanh", include_b1_in_Sigma=True, include_b2_in_Sigma=True
):
    """
    Theoretical eigenvalue bands of (I+G)^{-1}G across posterior draws.

    For each draw d:
      - build Φ, J_W, J_b1, J_b2 (build_hidden_and_jacobian_W),
      - bound Σ_y^{-1} between c_min and c_max (sigma_inverse_bounds),
      - build P (build_P_from_lambda_tau) and the α-eigenvalues of A = P^{-1/2} J_Wᵀ J_W P^{-1/2},
      - compute the lower/upper shrinkage eigenvalue bands; the df bounds are their sums.

    W_all is (D, H, p); the remaining *_all arrays are indexed per draw along axis 0.

    Returns
    -------
    lam_lo_med, lam_hi_med : (H*p,) coordinate-wise medians across draws of the lower/upper bands
                             (ascending eigenvalue order).
    lam_lo_q10, lam_hi_q90 : (H*p,) 10% quantile of the lower band and 90% quantile of the upper band.
    summary                : dict with df_lo_med, df_hi_med, df_lo_q10, df_hi_q90.
    """
    D, H, p = W_all.shape
    N = H * p
    lam_lo_stack = np.empty((D, N))
    lam_hi_stack = np.empty((D, N))

    for d in range(D):
        Phi_mat, JW, Jb1, Jb2 = build_hidden_and_jacobian_W(
            X, W_all[d], b_all[d], v_all[d], activation=activation
        )
        # sigma_inverse_bounds skips a bias block whenever its include flag is False.
        c_min, c_max, _ = sigma_inverse_bounds(
            Phi_mat, tau_v=float(tau_v_all[d]), noise=float(sigma_all[d]),
            J_b1=Jb1, J_b2=Jb2,
            include_b1=include_b1_in_Sigma, include_b2=include_b2_in_Sigma,
        )
        P = build_P_from_lambda_tau(lambda_all[d], tau_w=float(tau_w_all[d]))

        _, _, alpha, _, _, _, _ = bound_S_and_G(JW, P, c_min, c_max)
        lam_lo_stack[d], lam_hi_stack[d] = shrinkage_bands_from_alpha(alpha, c_min, c_max)

    # df_eff = tr((I+G)^{-1}G) = Σ_i λ_i, so the df bounds are the row sums of the bands.
    df_lo = lam_lo_stack.sum(axis=1)
    df_hi = lam_hi_stack.sum(axis=1)

    # Medians across draws (coordinate-wise)
    lam_lo_med = np.median(lam_lo_stack, axis=0)
    lam_hi_med = np.median(lam_hi_stack, axis=0)

    # Outer 10/90% ribbons
    lam_lo_q10 = np.quantile(lam_lo_stack, 0.10, axis=0)
    lam_hi_q90 = np.quantile(lam_hi_stack, 0.90, axis=0)

    summary = {
        "df_lo_med": float(np.median(df_lo)),
        "df_hi_med": float(np.median(df_hi)),
        "df_lo_q10": float(np.quantile(df_lo, 0.10)),
        "df_hi_q90": float(np.quantile(df_hi, 0.90)),
    }
    return lam_lo_med, lam_hi_med, lam_lo_q10, lam_hi_q90, summary



In [None]:
# --- (5) Gaussian prior: Σ_y^{-1} / df_eff bounds, then theory bands vs. empirical eigenvalues ---
import matplotlib.pyplot as plt

W_all, b_all, v_all, c_all, sigma_all, tau_w_all, tau_v_all, lambda_all = \
    extract_model_draws(posterior_N100_fits, model='Gaussian tanh', lambda_kind="effective")

# (a) Theoretical eigenvalue bands of (I+G)^{-1}G across draws (ascending order)
lam_lo_med, lam_hi_med, lam_lo_q10, lam_hi_q90, band_summ = eigen_bands_over_draws(
    X, W_all, b_all, v_all, sigma_all, tau_w_all, tau_v_all, lambda_all,
    activation="tanh", include_b1_in_Sigma=True, include_b2_in_Sigma=True
)

# Braces around -1 are doubled so the f-string prints a literal "^{-1}".
print("Median df_eff bounds (Gaussian):")
print(f"  lower  (Σ_y^{{-1}}≈1/(σ^2+||Q||^2)): {band_summ['df_lo_med']:.3f}")
print(f"  upper  (Σ_y^{{-1}}≈1/σ^2):           {band_summ['df_hi_med']:.3f}")

# (b) Empirical median eigenvalue curve of (I+G)^{-1}G, always recomputed from THIS
#     model's draws so it can never be a stale curve left over from another model.
_, _, _, _, shrink_stack, _, _ = compute_shrinkage(
    X, W_all, b_all, v_all, sigma_all, tau_w_all, tau_v_all, lambda_all,
    activation="tanh", include_b1_in_Sigma=True, include_b2_in_Sigma=True
)
D, N, _ = shrink_stack.shape
eigs = np.empty((D, N))
for d in range(D):
    eigs[d] = np.linalg.eigvalsh(shrink_stack[d])  # ascending, to match the bands
eig_med_shrink = np.median(eigs, axis=0)

# (c) Overlay: theory bands vs. empirical median eigenvalue curve
x = np.arange(1, lam_lo_med.size + 1)
plt.figure(figsize=(7.2, 4.2), dpi=150)
plt.fill_between(x, lam_lo_med, lam_hi_med, alpha=0.25, label="theory band (median)")
plt.plot(x, eig_med_shrink, lw=1.8, label="empirical median eigenvalue")
# thinner outer 10–90% theoretical ribbon
plt.fill_between(x, lam_lo_q10, lam_hi_q90, alpha=0.15, label="theory band (10–90%)")

plt.xlabel("eigenvalue index (ascending)")
plt.ylabel(r"eigenvalue of $(I+G)^{-1}G$")
plt.title("Shrinkage eigenvalue bands vs. empirical curve (Gaussian)")
plt.legend()
plt.tight_layout()
plt.show()


In [None]:
# --- (5) Regularized Horseshoe: Σ_y^{-1} / df_eff bounds, then theory bands vs. empirical eigenvalues ---
import matplotlib.pyplot as plt

W_all, b_all, v_all, c_all, sigma_all, tau_w_all, tau_v_all, lambda_all = \
    extract_model_draws(posterior_N100_fits, model='Regularized Horseshoe tanh', lambda_kind="effective")

# (a) Theoretical eigenvalue bands of (I+G)^{-1}G across draws (ascending order)
lam_lo_med, lam_hi_med, lam_lo_q10, lam_hi_q90, band_summ = eigen_bands_over_draws(
    X, W_all, b_all, v_all, sigma_all, tau_w_all, tau_v_all, lambda_all,
    activation="tanh", include_b1_in_Sigma=True, include_b2_in_Sigma=True
)

# Braces around -1 are doubled so the f-string prints a literal "^{-1}".
print("Median df_eff bounds (RHS):")
print(f"  lower  (Σ_y^{{-1}}≈1/(σ^2+||Q||^2)): {band_summ['df_lo_med']:.3f}")
print(f"  upper  (Σ_y^{{-1}}≈1/σ^2):           {band_summ['df_hi_med']:.3f}")

# (b) Empirical median eigenvalue curve of (I+G)^{-1}G, always recomputed from THIS
#     model's draws so it can never be a stale curve left over from another model.
_, _, _, _, shrink_stack, _, _ = compute_shrinkage(
    X, W_all, b_all, v_all, sigma_all, tau_w_all, tau_v_all, lambda_all,
    activation="tanh", include_b1_in_Sigma=True, include_b2_in_Sigma=True
)
D, N, _ = shrink_stack.shape
eigs = np.empty((D, N))
for d in range(D):
    eigs[d] = np.linalg.eigvalsh(shrink_stack[d])  # ascending, to match the bands
eig_med_shrink = np.median(eigs, axis=0)

# (c) Overlay: theory bands vs. empirical median eigenvalue curve
x = np.arange(1, lam_lo_med.size + 1)
plt.figure(figsize=(7.2, 4.2), dpi=150)
plt.fill_between(x, lam_lo_med, lam_hi_med, alpha=0.25, label="theory band (median)")
plt.plot(x, eig_med_shrink, lw=1.8, label="empirical median eigenvalue")
# thinner outer 10–90% theoretical ribbon
plt.fill_between(x, lam_lo_q10, lam_hi_q90, alpha=0.15, label="theory band (10–90%)")

plt.xlabel("eigenvalue index (ascending)")
plt.ylabel(r"eigenvalue of $(I+G)^{-1}G$")
plt.title("Shrinkage eigenvalue bands vs. empirical curve (RHS)")
plt.legend()
plt.tight_layout()
plt.show()


In [None]:
# --- (5) Dirichlet Horseshoe: Σ_y^{-1} / df_eff bounds, then theory bands vs. empirical eigenvalues ---
import matplotlib.pyplot as plt

W_all, b_all, v_all, c_all, sigma_all, tau_w_all, tau_v_all, lambda_all = \
    extract_model_draws(posterior_N100_fits, model='Dirichlet Horseshoe tanh', lambda_kind="effective")

# (a) Theoretical eigenvalue bands of (I+G)^{-1}G across draws (ascending order)
lam_lo_med, lam_hi_med, lam_lo_q10, lam_hi_q90, band_summ = eigen_bands_over_draws(
    X, W_all, b_all, v_all, sigma_all, tau_w_all, tau_v_all, lambda_all,
    activation="tanh", include_b1_in_Sigma=True, include_b2_in_Sigma=True
)

# Braces around -1 are doubled so the f-string prints a literal "^{-1}".
print("Median df_eff bounds (DHS):")
print(f"  lower  (Σ_y^{{-1}}≈1/(σ^2+||Q||^2)): {band_summ['df_lo_med']:.3f}")
print(f"  upper  (Σ_y^{{-1}}≈1/σ^2):           {band_summ['df_hi_med']:.3f}")

# (b) Empirical median eigenvalue curve of (I+G)^{-1}G, always recomputed from THIS
#     model's draws so it can never be a stale curve left over from another model.
_, _, _, _, shrink_stack, _, _ = compute_shrinkage(
    X, W_all, b_all, v_all, sigma_all, tau_w_all, tau_v_all, lambda_all,
    activation="tanh", include_b1_in_Sigma=True, include_b2_in_Sigma=True
)
D, N, _ = shrink_stack.shape
eigs = np.empty((D, N))
for d in range(D):
    eigs[d] = np.linalg.eigvalsh(shrink_stack[d])  # ascending, to match the bands
eig_med_shrink = np.median(eigs, axis=0)

# (c) Overlay: theory bands vs. empirical median eigenvalue curve
x = np.arange(1, lam_lo_med.size + 1)
plt.figure(figsize=(7.2, 4.2), dpi=150)
plt.fill_between(x, lam_lo_med, lam_hi_med, alpha=0.25, label="theory band (median)")
plt.plot(x, eig_med_shrink, lw=1.8, label="empirical median eigenvalue")
# thinner outer 10–90% theoretical ribbon
plt.fill_between(x, lam_lo_q10, lam_hi_q90, alpha=0.15, label="theory band (10–90%)")

plt.xlabel("eigenvalue index (ascending)")
plt.ylabel(r"eigenvalue of $(I+G)^{-1}G$")
plt.title("Shrinkage eigenvalue bands vs. empirical curve (DHS)")
plt.legend()
plt.tight_layout()
plt.show()


In [None]:
# --- (5) Dirichlet Student-t: Σ_y^{-1} / df_eff bounds, then theory bands vs. empirical eigenvalues ---
import matplotlib.pyplot as plt

W_all, b_all, v_all, c_all, sigma_all, tau_w_all, tau_v_all, lambda_all = \
    extract_model_draws(posterior_N100_fits, model='Dirichlet Student T tanh', lambda_kind="effective")

# (a) Theoretical eigenvalue bands of (I+G)^{-1}G across draws (ascending order)
lam_lo_med, lam_hi_med, lam_lo_q10, lam_hi_q90, band_summ = eigen_bands_over_draws(
    X, W_all, b_all, v_all, sigma_all, tau_w_all, tau_v_all, lambda_all,
    activation="tanh", include_b1_in_Sigma=True, include_b2_in_Sigma=True
)

# Braces around -1 are doubled so the f-string prints a literal "^{-1}".
print("Median df_eff bounds (DST):")
print(f"  lower  (Σ_y^{{-1}}≈1/(σ^2+||Q||^2)): {band_summ['df_lo_med']:.3f}")
print(f"  upper  (Σ_y^{{-1}}≈1/σ^2):           {band_summ['df_hi_med']:.3f}")

# (b) Empirical median eigenvalue curve of (I+G)^{-1}G, always recomputed from THIS
#     model's draws so it can never be a stale curve left over from another model.
_, _, _, _, shrink_stack, _, _ = compute_shrinkage(
    X, W_all, b_all, v_all, sigma_all, tau_w_all, tau_v_all, lambda_all,
    activation="tanh", include_b1_in_Sigma=True, include_b2_in_Sigma=True
)
D, N, _ = shrink_stack.shape
eigs = np.empty((D, N))
for d in range(D):
    eigs[d] = np.linalg.eigvalsh(shrink_stack[d])  # ascending, to match the bands
eig_med_shrink = np.median(eigs, axis=0)

# (c) Overlay: theory bands vs. empirical median eigenvalue curve
x = np.arange(1, lam_lo_med.size + 1)
plt.figure(figsize=(7.2, 4.2), dpi=150)
plt.fill_between(x, lam_lo_med, lam_hi_med, alpha=0.25, label="theory band (median)")
plt.plot(x, eig_med_shrink, lw=1.8, label="empirical median eigenvalue")
# thinner outer 10–90% theoretical ribbon
plt.fill_between(x, lam_lo_q10, lam_hi_q90, alpha=0.15, label="theory band (10–90%)")

plt.xlabel("eigenvalue index (ascending)")
plt.ylabel(r"eigenvalue of $(I+G)^{-1}G$")
plt.title("Shrinkage eigenvalue bands vs. empirical curve (DST)")
plt.legend()
plt.tight_layout()
plt.show()


## Build linearized $\bar{w}$

In [110]:
def compute_linearized_mean(
    X, y,
    W_all, b1_all, b2_all, v_all,          # (D,H,p), (D,H), (D,), (D,H)
    sigma_all, tau_w_all, tau_v_all,       # (D,), (D,), (D,)
    lambda_all,                            # (D,H,p)
    activation="tanh",
    return_mats=True,
    include_b1_in_Sigma: bool = True,      # pass-through to the Σ_y builder
    include_b2_in_Sigma: bool = True,      # pass-through to the Σ_y builder
    include_b1_in_y_star: bool = True,     # include J_b1 b1,0 in y*
):
    """
    Linearized posterior mean of the first-layer weights, per draw.

    Per draw d:
      - R_d, P_d, S_d, Sigma_y_d, J_w,d from compute_shrinkage_for_W_block
      - y*_d   = (y - b2_d*1) + J_w,d @ vec(W0_d) [+ J_b,d @ b1,0,d  if include_b1_in_y_star]
      - g_d    = J_w,d^T (Sigma_y_d^{-1} y*_d)         [via solve]
      - bar_w_d = (P_d + S_d)^{-1} g_d                 [via solve]

    Returns
    -------
    R_stack     : (D, N, N), N = H*p  (None if return_mats=False)
    w_bar_stack : (D, N)
    """
    import numpy as np

    # Convert before reading the shape so list inputs work; y is flattened to (n,).
    y = np.asarray(y, dtype=float).reshape(-1)

    D, H, p = W_all.shape
    N = H * p

    R_stack = np.empty((D, N, N)) if return_mats else None
    w_bar_stack = np.empty((D, N))

    for d in range(D):
        res = compute_shrinkage_for_W_block(
            X=X,
            W0=W_all[d],
            b0=b1_all[d],
            v0=v_all[d],
            noise=float(sigma_all[d]),
            tau_w=float(tau_w_all[d]),
            tau_v=float(tau_v_all[d]),
            lambda_tilde=lambda_all[d],
            activation=activation,
            include_b1_in_Sigma=include_b1_in_Sigma,
            include_b2_in_Sigma=include_b2_in_Sigma,
        )
        # Every supported signature starts with (R, P, S, Sigma_y, J_w, ...).
        R, P, S, Sigma_y, JW = res[:5]

        # --- y* = (y - b2*1) + J_w w1,0 [+ J_b b1,0] ---
        y_star = (y - float(b2_all[d])) + JW @ W_all[d].reshape(-1)
        if include_b1_in_y_star:
            # Build J_b explicitly rather than guessing it from res[5]: Φ (also
            # returned by the shrinkage routine) presumably has H columns as well,
            # so a shape test cannot tell the two apart.
            _, _, Jb, _ = build_hidden_and_jacobian_W(
                X, W_all[d], b1_all[d], v_all[d], activation=activation
            )
            y_star = y_star + Jb @ b1_all[d]

        # --- g = J_w^T Σ_y^{-1} y*  (solve, no explicit inverse) ---
        g = JW.T @ np.linalg.solve(Sigma_y, y_star)

        # --- \bar w = (P + S)^{-1} g  (solve, no explicit inverse) ---
        w_bar_stack[d] = np.linalg.solve(P + S, g)

        if return_mats:
            R_stack[d] = R

    return R_stack, w_bar_stack

# extract_model_draws returns values in (W_all, b_all, v_all, c_all, sigma, tau_w, tau_v, lambda)
# order, as unpacked elsewhere in this notebook. Here W1 is the first-layer W, W2 is the v
# array and b2 is the c array. compute_linearized_mean expects (W_all, b1_all, b2_all, v_all),
# so THIS model's W1 / W2 must be passed — not the unrelated globals `W` / `v`.
# NOTE(review): these calls use extract_model_draws' default lambda_kind, whereas the band
# cells pass lambda_kind="effective" — confirm this is intended.
W1, b1, W2, b2, sigma, tau_w, tau_v, lambda_tilde = extract_model_draws(posterior_N100_fits, model='Gaussian tanh')

R_stack_gauss, w_bar_stack_gauss = compute_linearized_mean(X, y, W1, b1, b2, W2, sigma, tau_w, tau_v, lambda_tilde, activation="tanh")

W1, b1, W2, b2, sigma, tau_w, tau_v, lambda_tilde = extract_model_draws(posterior_N100_fits, model='Regularized Horseshoe tanh')

R_stack_RHS, w_bar_stack_RHS = compute_linearized_mean(X, y, W1, b1, b2, W2, sigma, tau_w, tau_v, lambda_tilde, activation="tanh")

W1, b1, W2, b2, sigma, tau_w, tau_v, lambda_tilde = extract_model_draws(posterior_N100_fits, model='Dirichlet Horseshoe tanh')

R_stack_DHS, w_bar_stack_DHS = compute_linearized_mean(X, y, W1, b1, b2, W2, sigma, tau_w, tau_v, lambda_tilde, activation="tanh")

W1, b1, W2, b2, sigma, tau_w, tau_v, lambda_tilde = extract_model_draws(posterior_N100_fits, model='Dirichlet Student T tanh')

R_stack_DST, w_bar_stack_DST = compute_linearized_mean(X, y, W1, b1, b2, W2, sigma, tau_w, tau_v, lambda_tilde, activation="tanh")



In [111]:
# Raw posterior draws of the Stan variables W_1 and W_L for each prior (N=100 fits).
W_all_gauss, v_all_gauss = (
    posterior_N100_fits['Gaussian tanh']['posterior'].stan_variable(var) for var in ("W_1", "W_L")
)
W_all_RHS, v_all_RHS = (
    posterior_N100_fits['Regularized Horseshoe tanh']['posterior'].stan_variable(var) for var in ("W_1", "W_L")
)
W_all_DHS, v_all_DHS = (
    posterior_N100_fits['Dirichlet Horseshoe tanh']['posterior'].stan_variable(var) for var in ("W_1", "W_L")
)
W_all_DST, v_all_DST = (
    posterior_N100_fits['Dirichlet Student T tanh']['posterior'].stan_variable(var) for var in ("W_1", "W_L")
)

In [None]:
def align_and_compare(W_all, v_all, w_bar_stack, sort_key="abs_v"):
    """
    Align signs & permutations across draws, then compare the linearized mean with the posterior mean.

    For every draw the hidden units are
      1) sign-flipped so that v >= 0 (the W row and the w̄ row of a unit get the same sign), and
      2) permuted in descending order of a per-unit key,
    so the two means are compared in a common basis.

    Inputs
    ------
    W_all        : array-like, (D, H, p) or (D, p, H), possibly with singleton dims.
    v_all        : array-like, H values per draw, e.g. (D, H) or (D, H, 1).
    w_bar_stack  : array-like, H*p values per draw, e.g. (D, H*p), (D, H, p) or (D, 1, H*p).
    sort_key     : "abs_v" (sort by |v|) or "abs_v_times_rownorm" (|v| * ||W_h||).

    Returns
    -------
    W_fix        : (D, H, p)   sign/permutation aligned
    v_fix        : (D, H)
    wbar_fix     : (D, H, p)
    summary      : dict with RMSE, Corr, CosSim, SignAgree (means vs means in the aligned basis), H, p, N

    Raises
    ------
    ValueError : if H cannot be inferred, if the w̄ rows are not a multiple of H,
                 if a draw cannot be coerced to (H, p), or if sort_key is unknown.
    """
    import numpy as np

    W_all = np.asarray(W_all)
    v_all = np.asarray(v_all)
    w_bar_stack = np.asarray(w_bar_stack)

    D = W_all.shape[0]

    # --- H from v (one value per hidden unit), p from the w̄ row length ---
    H = np.squeeze(v_all[0]).size
    if H == 0:
        raise ValueError("v_all[0] seems empty; cannot infer H.")
    n_wbar = np.squeeze(w_bar_stack[0]).size
    if n_wbar % H != 0:
        raise ValueError(f"w_bar_stack rows have {n_wbar} elements, which is not a multiple of H={H}.")
    p = n_wbar // H
    N = H * p

    # alloc outputs
    W_fix = np.empty((D, H, p), dtype=float)
    v_fix = np.empty((D, H), dtype=float)
    wbar_fix = np.empty((D, H, p), dtype=float)

    def coerce_W(Wd):
        """Return one draw of W as (H, p); accepts (H, p) or (p, H) after squeezing singleton dims."""
        A = np.squeeze(np.asarray(Wd, dtype=float))
        if A.ndim < 2 and A.size == H * p:
            # H == 1 or p == 1 squeezes to a vector; both orientations reshape to the same (H, p).
            return A.reshape(H, p)
        if A.ndim != 2:
            raise ValueError(f"Unexpected W ndim={A.ndim}, shape={A.shape}")
        if A.shape == (H, p):
            return A
        if A.shape == (p, H):
            return A.T
        raise ValueError(f"Cannot coerce W of shape {A.shape} to (H,p)=({H},{p}).")

    def coerce_v(vd):
        """Return one draw of v as (H,)."""
        v = np.asarray(vd, dtype=float).ravel()
        if v.size != H:
            raise ValueError(f"v has size {v.size}, expected H={H}.")
        return v

    def coerce_wbar_row(wbd):
        """Return one w̄ row as (H, p) (row-major, matching W.reshape(-1))."""
        w = np.asarray(wbd, dtype=float).ravel()
        if w.size != H * p:
            raise ValueError(f"w_bar row has {w.size} elems but H*p={H*p}.")
        return w.reshape(H, p)

    for d in range(D):
        Wd = coerce_W(W_all[d])            # (H,p)
        vd = coerce_v(v_all[d])            # (H,)
        wbd = coerce_wbar_row(w_bar_stack[d])

        # 1) sign fix so v >= 0
        s = np.sign(vd)
        s[s == 0.0] = 1.0
        Wd = Wd * s[:, None]
        wbd = wbd * s[:, None]
        vd = np.abs(vd)

        # 2) permute units by a per-draw key (descending)
        if sort_key == "abs_v":
            idx = np.argsort(-vd)
        elif sort_key == "abs_v_times_rownorm":
            idx = np.argsort(-(vd * np.linalg.norm(Wd, axis=1)))
        else:
            raise ValueError(f"Unknown sort_key: {sort_key}")

        W_fix[d] = Wd[idx]
        wbar_fix[d] = wbd[idx]
        v_fix[d] = vd[idx]

    # Compare means in the aligned basis
    w_post_mean = W_fix.reshape(D, -1).mean(axis=0)   # (N,)
    w_lin_mean  = wbar_fix.reshape(D, -1).mean(axis=0)

    rmse = float(np.sqrt(np.mean((w_lin_mean - w_post_mean)**2)))
    corr = float(np.corrcoef(w_lin_mean, w_post_mean)[0, 1])
    cos  = float(np.dot(w_lin_mean, w_post_mean) /
                 (np.linalg.norm(w_lin_mean) * np.linalg.norm(w_post_mean)))
    sign_agree = float(np.mean(np.sign(w_lin_mean) == np.sign(w_post_mean)))

    summary = dict(RMSE=rmse, Corr=corr, CosSim=cos, SignAgree=sign_agree,
                   H=H, p=p, N=N)
    return W_fix, v_fix, wbar_fix, summary

active_cols = np.arange(5)          # input columns 0-4 treated as active; adjust if different
inactive_cols = np.arange(5, 10)    # input columns 5-9 treated as inactive

def block_stats(A, B, cols):
    """Correlation and RMSE between A[:, cols] and B[:, cols] (both flattened); returns (corr, rmse)."""
    a = np.ravel(A[:, cols])
    b = np.ravel(B[:, cols])
    corr = np.corrcoef(a, b)[0, 1]
    rmse = np.sqrt(np.mean(np.square(a - b)))
    return corr, rmse

# For each prior: align the draws, print the summary, the nRMSE / R^2 of the linearized
# vs. posterior mean, and the active / inactive block statistics — all for the SAME
# model, in that order. After the loop, W_fix, v_fix, wbar_fix, summary, H, p, lin,
# post, w_post_mean, w_lin_mean, nrmse and r2 hold the DST results.
for label, W_all_m, v_all_m, w_bar_m in (
    ("Gaussian", W_all_gauss, v_all_gauss, w_bar_stack_gauss),
    ("RHS", W_all_RHS, v_all_RHS, w_bar_stack_RHS),
    ("DHS", W_all_DHS, v_all_DHS, w_bar_stack_DHS),
    ("DST", W_all_DST, v_all_DST, w_bar_stack_DST),
):
    W_fix, v_fix, wbar_fix, summary = align_and_compare(W_all_m, v_all_m, w_bar_m, sort_key="abs_v")
    print(f"{label}: \n", summary)

    w_post_mean = W_fix.reshape(W_fix.shape[0], -1).mean(axis=0)
    w_lin_mean  = wbar_fix.reshape(wbar_fix.shape[0], -1).mean(axis=0)
    nrmse = np.linalg.norm(w_lin_mean - w_post_mean) / np.linalg.norm(w_post_mean)
    r2 = 1 - np.sum((w_lin_mean - w_post_mean)**2) / np.sum((w_post_mean - w_post_mean.mean())**2)
    print(f"nRMSE: {nrmse:.3f}, R^2: {r2:.3f} \n")

    H, p = summary["H"], summary["p"]
    lin = wbar_fix.mean(axis=0).reshape(H, p)
    post = W_fix.mean(axis=0).reshape(H, p)

    print("Active  -> Corr, RMSE:", block_stats(lin, post, active_cols))
    print("Inactive-> Corr, RMSE:", block_stats(lin, post, inactive_cols), "\n")



In [None]:
# --- Helper: pick a "MAP-like" representative draw and plot MAP vs. \bar{w} ---
import numpy as np
import matplotlib.pyplot as plt

def select_map_like_index(W_fix: np.ndarray) -> int:
    """
    Index of the draw whose aligned W is nearest (Frobenius norm) to the aligned posterior mean.

    Used as a representative "MAP-like" draw; no posterior density is evaluated.
    W_fix: (D, H, p) aligned weights (output of align_and_compare).
    """
    flat = W_fix.reshape(W_fix.shape[0], -1)
    centred = flat - flat.mean(axis=0)
    sq_dist = np.einsum('di,di->d', centred, centred)  # squared distance of each draw to the mean
    return int(sq_dist.argmin())

def plot_map_vs_barw(W_fix: np.ndarray, wbar_fix: np.ndarray, title: str = "", alpha=0.7,
                     eps: float = 1e-1):
    """
    Overlay scatter for one representative draw: its aligned W (dots) vs its linearized w̄ (crosses).

    The draw is picked by select_map_like_index (nearest to the aligned posterior mean).

    Parameters
    ----------
    W_fix, wbar_fix : (D, H, p) aligned arrays (outputs of align_and_compare).
    title           : plot title; a default title is used when empty.
    alpha           : marker transparency.
    eps             : faint horizontal guides are drawn at ±eps to help spot non-negligible weights.
    """
    _, H, p = W_fix.shape
    idx = select_map_like_index(W_fix)  # representative draw
    w_map = W_fix[idx].reshape(-1)
    w_bar = wbar_fix[idx].reshape(-1)

    x = np.arange(1, H * p + 1)
    fig, ax = plt.subplots(figsize=(10, 3.5), dpi=150)
    ax.scatter(x, w_map, s=12, marker='o', label="MAP-like $w$", alpha=alpha)
    ax.scatter(x, w_bar, s=18, marker='x', label=r"Linearized $\bar{w}$", alpha=alpha)

    # light vertical guides between hidden units
    for h in range(1, H):
        ax.axvline(h * p + 0.5, color='0.85', lw=1, zorder=0)

    # faint ±eps guides
    ax.axhline(eps, color='0.85', lw=1, zorder=0)
    ax.axhline(-eps, color='0.85', lw=1, zorder=0)

    ax.set_xlabel("parameter index (after alignment)")
    ax.set_ylabel("value")
    # raw string: "\~" was an invalid escape and rendered a tilde instead of the bar used in the legend
    ax.set_title(title if title else r"MAP-like $w$ vs linearized $\bar{w}$")
    ax.legend()
    fig.tight_layout()
    plt.show()


In [None]:
# --- Gaussian prior: align draws, report the summary, plot the representative draw ---
W_fix_g, v_fix_g, wbar_fix_g, summary_g = align_and_compare(
    W_all=W_all_gauss, v_all=v_all_gauss, w_bar_stack=w_bar_stack_gauss, sort_key="abs_v"
)
print("Gaussian summary:", summary_g)
plot_map_vs_barw(
    W_fix_g, wbar_fix_g,
    title=r"Gaussian prior: MAP-like $w$ vs linearized $\bar{w}$",
)


In [None]:
# --- Regularized Horseshoe prior: align draws, report the summary, plot the representative draw ---
W_fix_r, v_fix_r, wbar_fix_r, summary_r = align_and_compare(
    W_all=W_all_RHS, v_all=v_all_RHS, w_bar_stack=w_bar_stack_RHS, sort_key="abs_v"
)
print("RHS summary:", summary_r)
plot_map_vs_barw(
    W_fix_r, wbar_fix_r,
    title=r"RHS prior: MAP-like $w$ vs linearized $\bar{w}$",
)


In [None]:
# --- Dirichlet Horseshoe prior: align draws, report the summary, plot the representative draw ---
W_fix_dhs, v_fix_dhs, wbar_fix_dhs, summary_dhs = align_and_compare(
    W_all=W_all_DHS, v_all=v_all_DHS, w_bar_stack=w_bar_stack_DHS, sort_key="abs_v"
)
print("DHS summary:", summary_dhs)
plot_map_vs_barw(
    W_fix_dhs, wbar_fix_dhs,
    title=r"DHS prior: MAP-like $w$ vs linearized $\bar{w}$",
)


In [None]:

# --- Dirichlet Student-t prior: align draws, report the summary, plot the representative draw ---
W_fix_dst, v_fix_dst, wbar_fix_dst, summary_dst = align_and_compare(
    W_all=W_all_DST, v_all=v_all_DST, w_bar_stack=w_bar_stack_DST, sort_key="abs_v"
)
print("DST summary:", summary_dst)
plot_map_vs_barw(
    W_fix_dst, wbar_fix_dst,
    title=r"DST prior: MAP-like $w$ vs linearized $\bar{w}$",
)


## Difference between lambda_eff and lambda

In [None]:
# Effective per-eigendirection scales λ_eff vs. the λ̃ draws for the Regularized Horseshoe (draw index 1).
# NOTE(review): λ, λ̃ and τ come from the *prior* fit, while P_RHS / S_RHS are computed
# elsewhere in this notebook (presumably from the posterior) — confirm this mix is intended.
rhs_prior_fit = prior_N100_fits['Regularized Horseshoe']['posterior']
lambda_samples = rhs_prior_fit.stan_variable("lambda")[1].flatten()
reg_lambda_samples = rhs_prior_fit.stan_variable("lambda_tilde")[1].flatten()
tau_samples = rhs_prior_fit.stan_variable("tau")[1]

P = P_RHS[1]
S = S_RHS[1]

# Given: P (positive diagonal), S (symmetric). Whiten S by diag(P).
# Distinct names keep the module-level `p` (block size) and `W` from being clobbered.
P_diag = np.diag(P)                        # diagonal entries of P
P_inv_sqrt = np.diag(1.0 / np.sqrt(P_diag))  # P^{-1/2}
W_white = P_inv_sqrt @ S @ P_inv_sqrt      # whitened

# Symmetric EVD: r = eigenvalues (ascending), columns of U = eigenvectors
r, U = np.linalg.eigh(W_white)

inv_lambda2 = 1.0 / (lambda_samples**2)
# λ_eff_i^2 = 1 / Σ_j (U_{ji}^2 / λ_j^2): prior precision projected on eigenvector i
# (column i of U), hence the transpose. U**2 @ inv_lambda2 would instead sum over
# eigenvectors at a fixed coordinate i.
lambda_eff_sq = 1.0 / ((U**2).T @ inv_lambda2)
lambda_eff = np.sqrt(lambda_eff_sq)

print(np.median(lambda_eff), np.median(reg_lambda_samples))
# Show the shape and first values (a bare mid-cell expression would not be displayed).
print(lambda_eff.shape, lambda_eff[:5])

lam_eff = lambda_eff
s_hat = np.median(lam_eff)

# QQ points vs. a half-Cauchy with scale s_hat: F^{-1}(q) = s_hat * tan(pi q / 2)
probs = (np.arange(1, len(lam_eff) + 1) - 0.5) / len(lam_eff)
q_theory = s_hat * np.tan(0.5 * np.pi * probs)
q_emp = np.sort(lam_eff)


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import halfcauchy

# --- QQ plot function ---
def qq_plot(data, label, color, scale=1.0, ax=None):
    """Scatter the sorted sample against Half-Cauchy(0, ``scale``) quantiles.

    Parameters
    ----------
    data : array_like
        Sample values; the input is flattened before sorting.
    label, color :
        Passed to ``scatter`` for the legend entry and marker colour.
    scale : float, optional
        Scale of the reference Half-Cauchy distribution (default 1.0).
    ax : matplotlib Axes, optional
        Axes to draw on; defaults to the current pyplot axes.

    Returns
    -------
    The axes that were drawn on.
    """
    values = np.sort(np.asarray(data, dtype=float).ravel())
    n = values.size
    # Midpoint plotting positions (k - 0.5) / n avoid the infinite quantile at p = 1.
    probs = (np.arange(1, n + 1) - 0.5) / n
    q_theory = halfcauchy.ppf(probs, scale=scale)
    if ax is None:
        ax = plt.gca()
    ax.scatter(q_theory, values, label=label, alpha=0.7, color=color)
    return ax

# --- Tail plot function ---
def tail_plot(data, label, color, ax=None):
    """Plot the tail diagnostic x * S_n(x) at the sorted sample points.

    For sorted values x_(1) <= ... <= x_(n), the empirical survival used is
    S_n(x_(k)) = (n - k + 1) / n, i.e. the fraction of the sample >= x_(k).
    For a Half-Cauchy(0, s) tail, x * P(X > x) levels off at 2 s / pi.

    Parameters
    ----------
    data : array_like
        Sample values; the input is flattened before sorting.
    label, color :
        Passed to ``plot`` for the legend entry and line colour.
    ax : matplotlib Axes, optional
        Axes to draw on; defaults to the current pyplot axes.

    Returns
    -------
    The axes that were drawn on.
    """
    sorted_data = np.sort(np.asarray(data, dtype=float).ravel())
    n = sorted_data.size
    surv_emp = np.arange(n, 0, -1) / n  # empirical P(X >= x_(k))
    if ax is None:
        ax = plt.gca()
    ax.plot(sorted_data, sorted_data * surv_emp, label=label, color=color)
    return ax

# -----------------------------
# QQ plot: compare both samples against Half-Cauchy(0, 1) quantiles.
# qq_plot draws on the current axes, which plt.subplots makes `ax`.
fig, ax = plt.subplots(figsize=(6, 6))
qq_plot(lambda_eff, "lambda_eff", "C0")
qq_plot(reg_lambda_samples, "lambda_tilde draws", "C1")
lims = [0, max(np.max(lambda_eff), np.max(reg_lambda_samples), 10)]
ax.plot(lims, lims, 'k--', lw=1, label="y=x")
ax.set_xlabel("Theoretical Half-Cauchy(0,1) quantiles")
ax.set_ylabel("Empirical quantiles")
ax.set_title("QQ-plot vs Half-Cauchy(0,1)")
ax.legend()
ax.grid(True)
plt.show()

# -----------------------------
# Tail diagnostic: x * P(X > x) levels off (at 2/pi for Half-Cauchy(0, 1)) when the
# tail decays like 1/x. The reference curve uses a log-spaced grid that reaches the
# largest observation, so the data's right tail always has something to compare
# against. `sf` is used instead of 1 - cdf to keep precision far in the tail.
x_max = max(10.0, np.max(lambda_eff), np.max(reg_lambda_samples))
x = np.geomspace(0.1, x_max, 200)
surv_theory = halfcauchy.sf(x, scale=1)

fig, ax = plt.subplots(figsize=(7, 5))
tail_plot(lambda_eff, "lambda_eff", "C0")
tail_plot(reg_lambda_samples, "lambda_tilde draws", "C1")
ax.plot(x, x * surv_theory, 'k--', lw=1, label="Half-Cauchy(0,1)")
ax.set_xlabel("x")
ax.set_ylabel("x * Survival(x)")
ax.set_title("Tail diagnostic: x * P(X>x)")
ax.legend()
ax.grid(True)
plt.show()
