In [None]:
import sys, os; sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__) if '__file__' in globals() else os.getcwd(), '..')))
#import os; os.chdir(os.path.dirname(os.getcwd()))
from utils.model_loader import get_model_fits
import numpy as np
import pandas as pd
import re
#from sklearn.metrics import mean_squared_errosr
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Result folders for the prior-only runs and for the posterior (regression) runs.
results_dir_priors = "results/priors/single_layer/tanh/friedman"
results_dir_posteriors = "results/regression/single_layer/tanh/friedman"

prior_names = ["Dirichlet Horseshoe", "Regularized Horseshoe", "Dirichlet Student T", "Gaussian"]
# Posterior fits use the same model names with the activation appended.
posterior_names = [f"{name} tanh" for name in prior_names]

# One Friedman dataset configuration per sample size (each with its own seed).
_friedman_configs = {
    100: "Friedman_N100_p10_sigma1.00_seed1",
    200: "Friedman_N200_p10_sigma1.00_seed2",
    500: "Friedman_N500_p10_sigma1.00_seed11",
}


def _load_fits(config, results_dir, models):
    """Call the project loader get_model_fits for one configuration with include_prior=False."""
    return get_model_fits(
        config=config,
        results_dir=results_dir,
        models=models,
        include_prior=False,
    )


prior_N100_fits = _load_fits(_friedman_configs[100], results_dir_priors, prior_names)
prior_N200_fits = _load_fits(_friedman_configs[200], results_dir_priors, prior_names)
prior_N500_fits = _load_fits(_friedman_configs[500], results_dir_priors, prior_names)

posterior_N100_fits = _load_fits(_friedman_configs[100], results_dir_posteriors, posterior_names)
posterior_N200_fits = _load_fits(_friedman_configs[200], results_dir_posteriors, posterior_names)
posterior_N500_fits = _load_fits(_friedman_configs[500], results_dir_posteriors, posterior_names)


In [3]:
# NOTE(review): this loads the N=500 training set, but the shrinkage cell below
# uses draws from posterior_N100_fits. Confirm X should not instead come from
# datasets/friedman/Friedman_N100_p10_sigma1.00_seed1.npz.
path = "datasets/friedman/Friedman_N500_p10_sigma1.00_seed11.npz"
# np.load on an .npz returns an NpzFile that keeps the file handle open.
# Indexing it reads each array into memory, so the file can be closed right after.
with np.load(path) as data:
    X = data['X_train']
    y = data['y_train']

In [4]:
import numpy as np
from typing import Tuple, Callable

# ---------- Aktivasjon og deriverte ----------

def get_activation(activation: str = "tanh") -> Tuple[Callable, Callable]:
    """
    Return ``(phi, dphi)``: an elementwise activation and its derivative.

    Supported names are "tanh" and "relu". For relu the derivative at 0 is
    taken to be 0 and is returned in the input array's dtype.

    Raises:
      ValueError: for any other activation name.
    """
    if activation not in ("tanh", "relu"):
        raise ValueError(f"Unsupported activation: {activation}")

    if activation == "relu":
        def relu(a):
            return np.maximum(0.0, a)

        def relu_grad(a):
            return (a > 0.0).astype(a.dtype)

        return relu, relu_grad

    def tanh_grad(a):
        return 1.0 - np.tanh(a) ** 2

    return np.tanh, tanh_grad

# ---------- H(w0) og J_W(w0, v0) ----------

def build_hidden_and_jacobian_W(
    X: np.ndarray,               # (n, p) inputs
    W0: np.ndarray,              # (H, p) hidden weights at the reference point w0
    b0: np.ndarray,              # (H,)   hidden biases at w0
    v0: np.ndarray,              # (H,)   output weights at v0
    activation: str = "tanh",
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Hidden activations and Jacobians of f(x) = phi(W x + b) . v at (w0, v0).

    Returns:
      Hmat : (n, H)    = H(w0), the post-activation hidden units
      JW   : (n, H*p)  = d f / d vec(W) at (w0, v0). Column h*p + j holds
                         d f / d W_{h,j}, i.e. the row-major order of W0.reshape(-1).
      Jb   : (n, H)    = d f / d b at (w0, v0), i.e. v_h * phi'(a_{i,h})

    Raises:
      ValueError: if W0 and X disagree on the number of inputs p.
    """
    n, p = X.shape
    H, pW = W0.shape
    if pW != p:
        raise ValueError(f"W0 has {pW} input columns but X has {p}")
    phi, dphi = get_activation(activation)

    A = X @ W0.T + b0[None, :]        # (n, H) pre-activations a_{i,h}
    Hmat = phi(A)                     # (n, H) post-activations h_{i,h}
    Jb = dphi(A) * v0[None, :]        # (n, H) df/db_h = v_h * phi'(a_{i,h})

    # df/dW_{h,j} = v_h * phi'(a_{i,h}) * x_{i,j}: one outer product per row i.
    # Flattening the (H, p) trailing axes gives the (h, j) column order above.
    JW = (Jb[:, :, None] * X[:, None, :]).reshape(n, H * p)
    return Hmat, JW, Jb

# ---------- Sigma_y og P ----------

def build_Sigma_y(
    Hmat: np.ndarray,     # (n, H) hidden activations H(w0)
    tau_v: float,         # prior std of the output weights v
    J_b: np.ndarray,      # (n, H) Jacobian of the output w.r.t. the hidden biases
    sigma: float          # noise std of the likelihood
) -> np.ndarray:
    """
    Build the (n, n) matrix

        Sigma_y = tau_v^2 H H^T + J_b J_b^T + sigma^2 I_n

    from the output-weight term, the bias term and the noise term.
    """
    n_obs = Hmat.shape[0]
    output_term = (tau_v**2) * (Hmat @ Hmat.T)
    bias_term = J_b @ J_b.T
    noise_term = (sigma**2) * np.eye(n_obs)
    return output_term + bias_term + noise_term

def build_P_from_lambda_tau(
    lambda_tilde: np.ndarray,  # (H, p) local prior VARIANCE factors for W
    tau_w: float               # global scale for w
) -> np.ndarray:
    """
    Diagonal prior precision of vec(W): diag(P) = 1 / (tau_w^2 * lambda_tilde).

    ``lambda_tilde`` is used as a variance factor, so that
    var(W_hj) = tau_w^2 * lambda_tilde_hj. This matches the "effective
    variance factor" convention of lambda_all in this notebook. It is therefore
    NOT squared here; pass lambda**2 if you hold local standard-deviation scales.

    Returns:
      P : (H*p, H*p) diagonal matrix, ordered like lambda_tilde.reshape(-1).
    """
    lam_vec = np.asarray(lambda_tilde, dtype=float).reshape(-1)   # (H*p,)
    diagP = 1.0 / ((tau_w**2) * lam_vec)                          # (H*p,)
    return np.diag(diagP)

# ---------- S, shrinkage-matrise R = (P+S)^{-1} P ----------

def build_S(JW: np.ndarray, Sigma_y: np.ndarray) -> np.ndarray:
    """
    Data precision for vec(W): S = J_W^T Sigma_y^{-1} J_W, shape (Hp, Hp).

    Sigma_y^{-1} J_W comes from a linear solve, never an explicit inverse.
    """
    whitened_jacobian = np.linalg.solve(Sigma_y, JW)   # (n, Hp)
    return np.matmul(JW.T, whitened_jacobian)          # (Hp, Hp)

def shrinkage_matrix(P: np.ndarray, S: np.ndarray) -> np.ndarray:
    """
    Solve (P + S) R = P for R = (P + S)^{-1} P.

    Uses a Cholesky factorisation when P + S is positive definite. If that
    raises LinAlgError, it falls back to a general dense solve.
    """
    system = P + S
    try:
        lower = np.linalg.cholesky(system)
        # lower @ lower.T @ R = P: forward substitution, then back substitution.
        return np.linalg.solve(lower.T, np.linalg.solve(lower, P))
    except np.linalg.LinAlgError:
        return np.linalg.solve(system, P)

def shrinkage_matrix_stable(P, S, jitter=0.0):
    """
    Numerically stable R = (P + S)^{-1} P for a diagonal, positive P.

    With M = P^{-1/2} S P^{-1/2} we have P + S = P^{1/2} (I + M) P^{1/2}, so

        R = P^{-1/2} (I + M)^{-1} P^{1/2}.

    (I + M) is well scaled and SPD when S is PSD, so it is inverted through a
    Cholesky factor. The former P^{1/2} (I + M)^{-1} P^{1/2} equals
    P (P + S)^{-1} P, which is not R unless P = I. R is in general NOT
    symmetric: it is similar to the symmetric (I + M)^{-1}. So only that
    inner inverse is symmetrised, never R.

    Parameters:
      P      : (N, N) prior precision; only its diagonal is read, clipped to >= 1e-12.
      S      : (N, N) symmetric PSD data precision.
      jitter : optional ridge added to M before factorising.

    Returns:
      R : (N, N)

    Raises:
      numpy.linalg.LinAlgError: if I + M (+ jitter) is not positive definite.
    """
    d = np.diag(P).astype(float)
    # Guard against zeros/negatives on the diagonal (NaN is not removed by clip).
    d = np.clip(d, 1e-12, np.finfo(float).max)
    sqrt_d = np.sqrt(d)
    inv_sqrt_d = 1.0 / sqrt_d

    # M = P^{-1/2} S P^{-1/2}; for diagonal P this is a row/column rescaling.
    M = inv_sqrt_d[:, None] * S * inv_sqrt_d[None, :]
    M = 0.5 * (M + M.T)
    if jitter > 0:
        M = M + jitter * np.eye(M.shape[0])

    I = np.eye(M.shape[0])
    L = np.linalg.cholesky(I + M)
    # K = (I + M)^{-1} via two triangular-system solves; symmetrise roundoff.
    K = np.linalg.solve(L.T, np.linalg.solve(L, I))
    K = 0.5 * (K + K.T)
    # R = P^{-1/2} K P^{1/2}: scale row i by d_i^{-1/2} and column j by d_j^{1/2}.
    return inv_sqrt_d[:, None] * K * sqrt_d[None, :]

def shrinkage_eigs_and_df(P, S):
    """
    Eigenvalues r = 1/(1 + mu) and df_eff = sum(1 - r), in P-whitened coordinates.

    mu are the eigenvalues of P^{-1/2} S P^{-1/2}. The diagonal of P is floored
    at 1e-12 before the inverse square root.
    """
    prior_prec = np.maximum(np.diag(P).astype(float), 1e-12)
    whiten = np.diag(1.0 / np.sqrt(prior_prec))

    mu = np.linalg.eigvalsh(whiten @ S @ whiten)   # ascending, >= 0 for PSD S
    retained = 1.0 / (1.0 + mu)                    # in (0, 1] for mu >= 0
    return retained, np.sum(1.0 - retained)

def extract_model_draws(fit_dict, model: str):
    """
    Collect every posterior draw of a single-hidden-layer network into arrays.

    ``lambda_all`` is the EFFECTIVE prior variance factor of each weight:
      Gaussian:                 1
      Regularized Horseshoe:    lambda_tilde
      Dirichlet (DHS/DST):      lambda_tilde_data * phi_data

    Parameters:
      fit_dict : mapping where ``fit_dict[model]['posterior']`` provides
                 ``stan_variable(name)``.
      model    : key into fit_dict. The prior family is detected from the
                 substrings "Gaussian", "Regularized Horseshoe", "Dirichlet", "DST".

    Returns (D draws, H hidden units, p inputs):
      W_all (D, H, p), b_all (D, H), v_all (D, H), c_all (D,), sigma_all (D,),
      tau_w_all (D,), tau_v_all (D,) -- ones when "tau_v" cannot be read,
      lambda_all (D, H, p)
    """
    posterior = fit_dict[model]['posterior']

    def draws(name):
        return np.asarray(posterior.stan_variable(name))

    # W_1 is stored as (D, p, H); reorder to (D, H, p).
    W_all = np.transpose(draws("W_1"), (0, 2, 1))
    n_draws, n_hidden, n_inputs = W_all.shape

    v_all = draws("W_L").reshape(n_draws, -1)          # (D, H, 1) -> (D, H)
    b_all = draws("hidden_bias").reshape(n_draws, -1)  # (D, 1, H) -> (D, H)
    c_all = draws("output_bias").reshape(n_draws)      # (D, 1)    -> (D,)
    sigma_all = draws("sigma").reshape(n_draws)        # (D,)

    if "Gaussian" in model:
        # Unit global scales and unit local variance factors.
        return (W_all, b_all, v_all, c_all, sigma_all,
                np.ones(n_draws), np.ones(n_draws),
                np.ones((n_draws, n_hidden, n_inputs)))

    def as_hidden_by_input(arr):
        # Keep arrays already laid out as (D, H, p); otherwise assume (D, p, H).
        # NOTE(review): if H == p both layouts pass this shape test, so a
        # (D, p, H) array would be left untransposed -- confirm when H == p.
        if arr.shape[1:] == (n_hidden, n_inputs):
            return arr
        return np.transpose(arr, (0, 2, 1))

    tau_w_all = draws("tau").reshape(n_draws)
    try:
        tau_v_all = draws("tau_v").reshape(n_draws)
    except Exception:
        # Any failure to read tau_v falls back to a unit output-weight scale.
        tau_v_all = np.ones(n_draws)

    is_rhs = "Regularized Horseshoe" in model
    is_dirichlet = ("Dirichlet" in model) or ("DST" in model)

    local_var = as_hidden_by_input(draws("lambda_tilde" if is_rhs else "lambda_tilde_data"))
    if not is_dirichlet:
        # RHS: var = tau^2 * lambda_tilde
        return W_all, b_all, v_all, c_all, sigma_all, tau_w_all, tau_v_all, local_var

    # Stan: stddev = tau * sqrt(lambda_tilde) * sqrt(phi)  =>  var = tau^2 * lambda_tilde * phi
    phi_hp = as_hidden_by_input(draws("phi_data"))
    return W_all, b_all, v_all, c_all, sigma_all, tau_w_all, tau_v_all, local_var * phi_hp

# ------- Knyt alt sammen -------

def compute_shrinkage_for_W_block(
    X: np.ndarray,
    W0: np.ndarray, b0: np.ndarray, v0: np.ndarray,
    sigma: float, tau_w: float, tau_v: float,
    lambda_tilde: np.ndarray,
    activation: str = "tanh"
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Shrinkage quantities for the hidden-weight block W of a single draw.

    The network is linearised at (W0, b0, v0), after which
      Sigma_y = tau_v^2 H H^T + J_b J_b^T + sigma^2 I_n   (n, n)
      P       = diag(1 / (tau_w^2 * lambda_tilde))         (Hp, Hp)
      S       = J_W^T Sigma_y^{-1} J_W                     (Hp, Hp)
      R       = shrinkage_matrix_stable(P, S), intended as (P + S)^{-1} P

    Returns:
      (R, P, S, Sigma_y, JW, Hmat) with JW of shape (n, Hp) and Hmat (n, H).
    """
    Hmat, JW, Jb = build_hidden_and_jacobian_W(X, W0, b0, v0, activation=activation)  # (n,H), (n,Hp), (n,H)
    Sigma_y = build_Sigma_y(Hmat, tau_v=tau_v, J_b=Jb, sigma=sigma)                    # (n,n)
    P = build_P_from_lambda_tau(lambda_tilde, tau_w=tau_w)                             # (Hp,Hp)
    S = build_S(JW, Sigma_y)                                                           # (Hp,Hp)
    R = shrinkage_matrix_stable(P, S)                                                  # (Hp,Hp)
    return R, P, S, Sigma_y, JW, Hmat


def compute_shrinkage(
    X,
    W_all, b_all, v_all,          # (D,H,p), (D,H), (D,H)
    sigma_all, tau_w_all, tau_v_all,  # (D,), (D,), (D,)
    lambda_all,                   # (D,H,p)
    activation="tanh",
    return_mats=True,             # set False if you only want summaries
):
    """
    Per-draw shrinkage analysis of the hidden-weight block W.

    For every draw d, compute_shrinkage_for_W_block linearises the network at
    (W_all[d], b_all[d], v_all[d]) and returns R, P and S. From these the
    prior-whitened data precision W = P^{-1/2} S P^{-1/2} and (I + W)^{-1} W
    are formed.

    Parameters:
      return_mats : if False, the per-draw (N, N) matrices are neither computed
                    nor stored, and the five matrix stacks are returned as None.

    Returns (N = H * p):
      R_stack      : (D, N, N) R from compute_shrinkage_for_W_block
      S_stack      : (D, N, N) S = J_W^T Sigma_y^{-1} J_W
      P_stack      : (D, N, N) diagonal prior precision
      W_stack      : (D, N, N) P^{-1/2} S P^{-1/2}
      shrink_stack : (D, N, N) (I + W)^{-1} W
      r_eigs       : (D, N)    sorted r = 1/(1 + mu), mu the eigenvalues of W
      df_eff       : (D,)      effective d.o.f. = sum(1 - r)
    """
    D, H, p = W_all.shape
    N = H * p

    if return_mats:
        R_stack, S_stack, P_stack, W_stack, shrink_stack = (np.empty((D, N, N)) for _ in range(5))
    else:
        R_stack = S_stack = P_stack = W_stack = shrink_stack = None
    r_eigs = np.empty((D, N))
    df_eff = np.empty(D)
    identity = np.eye(N)

    for d in range(D):
        R, P, S, _, _, _ = compute_shrinkage_for_W_block(
            X=X,
            W0=W_all[d],
            b0=b_all[d],
            v0=v_all[d],
            sigma=float(sigma_all[d]),
            tau_w=float(tau_w_all[d]),
            tau_v=float(tau_v_all[d]),
            lambda_tilde=lambda_all[d],
            activation=activation,
        )

        if return_mats:
            # P is diagonal, so P^{-1/2} S P^{-1/2} is a row/column rescaling of S.
            inv_sqrt_prior = 1.0 / np.sqrt(np.diag(P))
            W_white = inv_sqrt_prior[:, None] * S * inv_sqrt_prior[None, :]
            R_stack[d] = R
            S_stack[d] = S
            P_stack[d] = P
            W_stack[d] = W_white
            # (I + W)^{-1} W through a linear solve instead of an explicit inverse.
            shrink_stack[d] = np.linalg.solve(identity + W_white, W_white)

        r, df = shrinkage_eigs_and_df(P, S)
        r_eigs[d] = np.sort(r)
        df_eff[d] = df

    return R_stack, S_stack, P_stack, W_stack, shrink_stack, r_eigs, df_eff


# ---------- Shrinkage analysis for every posterior model (N = 100 fits) ----------
# NOTE(review): the draws come from posterior_N100_fits while X is whatever the
# data-loading cell put in scope -- confirm both refer to the same dataset.
_shrinkage_runs = {}
for _tag, _model in (
    ("gauss", "Gaussian tanh"),
    ("RHS", "Regularized Horseshoe tanh"),
    ("DHS", "Dirichlet Horseshoe tanh"),
    ("DST", "Dirichlet Student T tanh"),
):
    W, b1, v, b2, sigma, tau_w, tau_v, lambda_tilde = extract_model_draws(posterior_N100_fits, model=_model)
    _shrinkage_runs[_tag] = compute_shrinkage(X, W, b1, v, sigma, tau_w, tau_v, lambda_tilde, activation="tanh")

R_gauss, S_gauss, P_gauss, W_gauss, shrink_gauss, eigs_gauss, df_eff_gauss = _shrinkage_runs["gauss"]
R_RHS, S_RHS, P_RHS, W_RHS, shrink_RHS, eigs_RHS, df_eff_RHS = _shrinkage_runs["RHS"]
R_DHS, S_DHS, P_DHS, W_DHS, shrink_DHS, eigs_DHS, df_eff_DHS = _shrinkage_runs["DHS"]
R_DST, S_DST, P_DST, W_DST, shrink_DST, eigs_DST, df_eff_DST = _shrinkage_runs["DST"]


In [None]:

import matplotlib.pyplot as plt
from matplotlib.colors import SymLogNorm, Normalize, TwoSlopeNorm
import numpy as np


def add_block_grid(ax, H, p, color="w", lw=0.5):
    """
    Draw separator lines between the H x H grid of p x p blocks on an image axis.

    Lines go at every interior block boundary k = h * p (h = 1 .. H-1), both
    horizontally and vertically. They are offset by -0.5 because imshow pixel
    edges sit at half-integers.
    """
    for h in range(1, H):
        boundary = h * p - 0.5
        ax.axhline(boundary, color=color, lw=lw)
        ax.axvline(boundary, color=color, lw=lw)

def visualize_models(
    matrices, names, H=16, p=10, use_abs=False, cmap="magma",
    q_low=0.05, q_high=0.95
):
    """
    Plot heatmaps of several matrices on one shared, outlier-robust colour scale.

    vmin / vmax are the q_low / q_high quantiles of all finite entries pooled
    over every matrix. Values outside [vmin, vmax] are clipped to the ends.
    White lines mark the H x H grid of p x p blocks.

    Panels are laid out two per row, so any number of matrices can be shown;
    four matrices give a 2 x 2 grid. Matrices and names are zipped, so only
    min(len(matrices), len(names)) panels are drawn.
    """
    # Optional absolute value
    mats = [np.abs(M) if use_abs else M for M in matrices]

    # Pool every finite entry to set a common colour scale.
    finite_parts = [M[np.isfinite(M)].ravel() for M in mats]
    all_vals = np.concatenate(finite_parts) if finite_parts else np.array([])

    if all_vals.size == 0:
        vmin, vmax = -1.0, 1.0
    else:
        vmin = float(np.quantile(all_vals, q_low))
        vmax = float(np.quantile(all_vals, q_high))
        if not (np.isfinite(vmin) and np.isfinite(vmax)) or vmin == vmax:
            # Degenerate scale (constant data or overflow): use mean +/- 1.
            centre = float(np.mean(all_vals))
            if not np.isfinite(centre):
                centre = 0.0
            vmin, vmax = centre - 1.0, centre + 1.0

    norm = Normalize(vmin=vmin, vmax=vmax)

    panels = list(zip(mats, names))
    n_panels = max(len(panels), 1)
    ncols = min(2, n_panels)
    nrows = -(-n_panels // ncols)  # ceiling division
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(5 * ncols, 5 * nrows), dpi=150,
        constrained_layout=True, squeeze=False,
    )
    axes = axes.ravel()

    im = None
    for ax, (M, title) in zip(axes, panels):
        Mplot = np.clip(M, vmin, vmax)  # clip values outside the interval
        im = ax.imshow(Mplot, aspect='equal', interpolation='nearest', cmap=cmap, norm=norm)
        add_block_grid(ax, H, p)
        ax.set_title(title)
        ax.set_xlabel("Columns")
        ax.set_ylabel("Rows")
    # Hide grid cells that received no matrix.
    for ax in axes[len(panels):]:
        ax.set_visible(False)

    # Shared colorbar over the panels actually drawn
    if im is not None:
        fig.colorbar(im, ax=list(axes[:len(panels)]), label=("|Value|" if use_abs else "Value"))

    # Report the limits that were actually used
    print(f"Felles kvantiler: vmin (q={q_low}) = {vmin:.6g}, vmax (q={q_high}) = {vmax:.6g}")

    plt.show()

# Posterior-mean prior-whitened data precision W = P^{-1/2} S P^{-1/2} per model.
# The same helper can show the R_*, S_*, P_* or shrink_* stacks instead
# (mean or median over axis 0); update the titles below to match.
matrices_W = [
    W_gauss.mean(axis=0),
    W_RHS.mean(axis=0),
    W_DHS.mean(axis=0),
    W_DST.mean(axis=0),
]

# Panel titles describe the matrices actually plotted (W, not R).
names = [
    "W (Gauss)",
    "W (RHS)",
    "W (DHS)",
    "W (DST)",
]

visualize_models(matrices_W, names, H=16, p=10, use_abs=False)


In [38]:
import numpy as np
import matplotlib.pyplot as plt

# ----------------------------- #
# 1) Spectra for W and (I-R)
# ----------------------------- #
def spectra_from_R_W(W_stack):
    """
    Eigen-spectra of a stack of whitened data-precision matrices.

    Parameters:
      W_stack : (D, N, N), symmetric up to numerical noise (symmetrised here).

    Returns:
      omega : (D, N) eigenvalues of W, ascending, tiny negatives clipped to 0
      r     : (D, N) omega / (1 + omega), the eigenvalues of (I + W)^{-1} W
      df    : (D,)   sum of r per draw
    """
    symmetric = (W_stack + np.swapaxes(W_stack, 1, 2)) * 0.5
    omega = np.clip(np.linalg.eigvalsh(symmetric), 0.0, None)
    fraction = omega / (omega + 1.0)
    return omega, fraction, np.sum(fraction, axis=1)

# ----------------------------- #
# 2) Plots for the spectra
# ----------------------------- #
def plot_spectra(omega, r, model_name="", bins=60):
    """
    Pooled views (all draws, all modes) of the spectra from spectra_from_R_W.

    Left: density histogram of log(1 + omega), omega the eigenvalues of W.
    Right: empirical CDF of r = omega / (1 + omega). The k-th smallest of n
    values has ECDF k / n, drawn as a right-continuous step function.
    """
    om = np.ravel(omega)
    rr_sorted = np.sort(np.ravel(r))
    fig, ax = plt.subplots(1, 2, figsize=(12, 4.5))

    # left: histogram of log(1+omega) (handles wide range)
    ax[0].hist(np.log1p(om), bins=bins, density=True)
    ax[0].set_xlabel("log(1 + ω)  (ω eigenvalues of W)")
    ax[0].set_ylabel("Density")
    ax[0].set_title(f"{model_name} — spectrum of W")

    # right: ECDF of r in (0,1); previously linspace(0, 1, n), which starts at 0 instead of 1/n
    ecdf = np.arange(1, rr_sorted.size + 1) / max(rr_sorted.size, 1)
    ax[1].step(rr_sorted, ecdf, where="post")
    ax[1].set_xlim(0, 1)
    ax[1].set_xlabel("r = ω/(1+ω)  (eigs of (I+W)^{-1}W ≡ (I-R))")
    ax[1].set_ylabel("ECDF")
    ax[1].set_title(f"{model_name} — shrinkage eigenvalues")
    plt.tight_layout()
    plt.show()

def plot_df(df, model_name=""):
    """Density histogram (50 bins) of per-draw effective degrees of freedom."""
    plt.figure(figsize=(6,3))
    plt.hist(df, bins=50, density=True)
    plt.title(f"{model_name} — effective d.f.")
    plt.ylabel("Density")
    plt.xlabel("df = sum of shrinkage eigenvalues")
    plt.tight_layout()
    plt.show()

# ----------------------------- #
# 3) Block strength (10x10 blocks)
# ----------------------------- #
def block_strength_grid(M, H=16, p=10, stat="mean_abs"):
    """
    Summarise one (H*p, H*p) matrix as an (H, H) grid of p x p block strengths.

    Block (i, j) covers rows i*p:(i+1)*p and columns j*p:(j+1)*p, so diagonal
    blocks correspond to hidden units. Only the leading H*p x H*p corner of M
    is used.

    stat:
      "mean_abs" : mean |entry| of the block (default)
      "fro"      : Frobenius norm / p, i.e. root-mean-square entry of the
                   p x p block (was / p**1.5, which is not size-invariant)
      "spec"     : spectral norm (largest singular value)

    Raises:
      ValueError: for an unknown stat, or when M is smaller than H*p x H*p.
    """
    if stat not in ("mean_abs", "fro", "spec"):
        raise ValueError("stat must be 'mean_abs', 'fro', or 'spec'")
    Hp = H * p
    M = np.asarray(M)
    if M.ndim != 2 or M.shape[0] < Hp or M.shape[1] < Hp:
        raise ValueError(f"M must be at least ({Hp}, {Hp}), got {M.shape}")
    if H == 0 or p == 0:
        return np.zeros((H, H), dtype=float)

    # blocks[i, j] is the p x p block at block-row i, block-column j.
    blocks = M[:Hp, :Hp].reshape(H, p, H, p).transpose(0, 2, 1, 3)
    if stat == "mean_abs":
        return np.abs(blocks).mean(axis=(2, 3)).astype(float)
    if stat == "fro":
        return (np.linalg.norm(blocks, axis=(2, 3)) / p).astype(float)
    return np.linalg.svd(blocks, compute_uv=False)[..., 0].astype(float)

def summarize_blocks_many(M_stack, H=16, p=10, stat="mean_abs"):
    """
    Block-strength summaries over a stack of matrices.

    Each M_stack[d] goes through block_strength_grid with the given H, p, stat.

    Returns:
      mean_grid : (H, H) mean block-strength grid over draws
      diag_mu   : (H,)   mean of the diagonal block strengths over draws
      diag_lo   : (H,)   5th percentile per diagonal block
      diag_hi   : (H,)   95th percentile per diagonal block
    """
    n_draws, _, _ = M_stack.shape
    grids = np.empty((n_draws, H, H))
    for idx in range(n_draws):
        grids[idx] = block_strength_grid(M_stack[idx], H=H, p=p, stat=stat)

    # (D, H): the diagonal of every per-draw grid.
    diagonals = np.diagonal(grids, axis1=1, axis2=2)
    return (
        grids.mean(axis=0),
        diagonals.mean(axis=0),
        np.quantile(diagonals, 0.05, axis=0),
        np.quantile(diagonals, 0.95, axis=0),
    )

def plot_block_heatmaps(mean_grid_W, mean_grid_shrink, stat="mean_abs", model_name=""):
    """
    Side-by-side block-strength heatmaps of W and (I-R) on one shared colour scale.

    vmin/vmax span both grids, so colours are directly comparable across panels.
    constrained_layout is used because tight_layout cannot place a colorbar
    that is shared by several axes.
    """
    vmin = min(mean_grid_W.min(), mean_grid_shrink.min())
    vmax = max(mean_grid_W.max(), mean_grid_shrink.max())
    fig, ax = plt.subplots(1, 2, figsize=(11, 4.8), constrained_layout=True)

    im = None
    for axis, grid, label in ((ax[0], mean_grid_W, "W"), (ax[1], mean_grid_shrink, "(I-R)")):
        im = axis.imshow(grid, vmin=vmin, vmax=vmax, origin="lower", aspect="equal")
        axis.set_title(f"{model_name}: block strength of {label}  ({stat})")
        axis.set_xlabel("block j")
        axis.set_ylabel("block i")

    cbar = fig.colorbar(im, ax=ax.ravel().tolist(), shrink=0.9)
    cbar.set_label("block strength")
    plt.show()

def plot_diag_blocks(diag_mu_W, diag_lo_W, diag_hi_W,
                     diag_mu_S, diag_lo_S, diag_hi_S,
                     model_name=""):
    """
    Per-hidden-unit (diagonal block) strength of W and (I-R).

    Each series is drawn as a mean line plus a shaded 5–95% band.
    """
    units = np.arange(diag_mu_W.size)
    series = (
        (diag_mu_W, diag_lo_W, diag_hi_W, "W (mean)"),
        (diag_mu_S, diag_lo_S, diag_hi_S, "(I-R) (mean)"),
    )
    plt.figure(figsize=(12,4))
    for mean_line, lower, upper, label in series:
        plt.plot(units, mean_line, label=label, lw=2)
        plt.fill_between(units, lower, upper, alpha=0.15)
    plt.xlabel("Hidden unit (block index)")
    plt.ylabel("Diagonal block strength")
    plt.title(f"{model_name}: diagonal block strengths (mean ± 5–95%)")
    plt.legend()
    plt.tight_layout()
    plt.show()


In [None]:
# Spectrum and block-structure diagnostics for the Gaussian-prior model.
model_label = "Gaussian"

# Block sizes follow the data instead of being hard-coded:
# p inputs per hidden unit, H hidden units, N = H * p weights in W.
p = X.shape[1]
H = W_gauss.shape[-1] // p
if H * p != W_gauss.shape[-1]:
    raise ValueError(f"W_gauss has size {W_gauss.shape[-1]}, not a multiple of p={p}")

# Spectra
omega, r, df = spectra_from_R_W(W_gauss)
plot_spectra(omega, r, model_name=model_label)
plot_df(df, model_name=model_label)

# Block strengths
# (I-R) is the global shrinkage operator; compute it directly
I_minus_R_stack = np.eye(H * p)[None, :, :] - R_gauss

# Mean block heatmaps (same color scale)
meanW, dmuW, dloW, dhiW = summarize_blocks_many(W_gauss, H=H, p=p, stat="mean_abs")
meanS, dmuS, dloS, dhiS = summarize_blocks_many(I_minus_R_stack, H=H, p=p, stat="mean_abs")

plot_block_heatmaps(meanW, meanS, stat="mean_abs", model_name=model_label)
plot_diag_blocks(dmuW, dloW, dhiW, dmuS, dloS, dhiS, model_name=model_label)


In [None]:
# Retained fraction 1/(1 + w) for each eigenvalue w of the draw-averaged
# prior-whitened data precision (DST model).
W_bar = np.mean(W_DST, axis=0)
w, U = np.linalg.eigh(W_bar)
shrink = 1 / (1 + w)
print(shrink)

In [None]:
# Retained fraction 1/(1 + w) for each eigenvalue w of the draw-averaged
# prior-whitened data precision (RHS model).
W_bar = np.mean(W_RHS, axis=0)
w, U = np.linalg.eigh(W_bar)
shrink = 1 / (1 + w)
print(shrink)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# ------- summaries -------
def summarize_df(df):
    """
    Summarise a sample of effective degrees of freedom.

    Returns a dict with keys "mean", "sd" (ddof=1), "q10", "q50", "q90",
    in that order, all as Python floats.
    """
    values = np.asarray(df).reshape(-1)
    summary = {
        "mean": float(np.mean(values)),
        "sd": float(np.std(values, ddof=1)),
    }
    for key, level in (("q10", 0.10), ("q50", 0.50), ("q90", 0.90)):
        summary[key] = float(np.quantile(values, level))
    return summary

def summarize_eigs(eigs, small=0.1, large=0.9):
    """
    Summarise pooled r-eigenvalues in [0, 1] (any shape, flattened).

    Returns the median, the 10%/90% quantiles, and the fractions of values
    strictly below ``small`` and strictly above ``large``.
    """
    flat = np.asarray(eigs).reshape(-1)
    median_r = float(np.median(flat))
    q10_r = float(np.quantile(flat, 0.10))
    q90_r = float(np.quantile(flat, 0.90))
    below = float(np.mean(flat < small))
    above = float(np.mean(flat > large))
    return {
        "median_r": median_r,
        "q10_r": q10_r,
        "q90_r": q90_r,
        "frac_small(<{:.2f})".format(small): below,
        "frac_large(>{:.2f})".format(large): above,
    }

# ------- Histogram plotting helpers -------

def plot_hist_eigs(model_eigs_dict, bins=50, title=""):
    """
    Overlaid density histograms of r-eigenvalues, one per model.

    model_eigs_dict: {"Name": eigs}; eigs of shape (D, N) or (D*N,) with values
                     in [0, 1] (clipped to that range for safety).
    bins: an int gives shared equal-width edges on [0, 1] so models are
          compared on the same grid; anything else is passed to plt.hist.
    """
    plt.figure(figsize=(9,5))

    edges = np.linspace(0.0, 1.0, bins + 1) if isinstance(bins, int) else bins

    for name, eigs in model_eigs_dict.items():
        values = np.clip(np.asarray(eigs).ravel(), 0.0, 1.0)
        plt.hist(values, bins=edges, density=True, alpha=0.5, label=name)

    plt.xlabel("r eigenvalues (0 = no shrink, 1 = hard shrink)")
    plt.ylabel("Density")
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.show()


# ------- optional: rho transform (prior/data ratio) -------
def plot_hist_log1p_rho(model_eigs_dict, bins=50, eps=1e-12, title=""):
    """
    Overlaid density histograms of log(1 + rho), rho = r / (1 - r), one per model.

    model_eigs_dict: {"Name": eigs}; eigs of shape (D, N) or (D*N,) holding
                     r-eigenvalues in [0, 1]. r is clipped to [0, 1 - eps] so
                     that rho stays finite at r = 1.
    bins: an int gives shared edges on [0, q_0.995] of the pooled transformed
          values, so the extreme 0.5% tail falls outside the plotted range.
          Anything else is passed to plt.hist unchanged.
    """
    def to_log1p_rho(eigs):
        r = np.clip(np.asarray(eigs).ravel(), 0.0, 1.0 - eps)
        return np.log1p(r / (1.0 - r))

    # Transform each model once; reuse it for the bin edges and the histograms.
    transformed = {name: to_log1p_rho(eigs) for name, eigs in model_eigs_dict.items()}

    if isinstance(bins, int):
        pooled = np.concatenate(list(transformed.values())) if transformed else np.empty(0)
        hi = float(np.quantile(pooled, 0.995)) if pooled.size else 0.0
        if not hi > 0.0:
            # Empty input or all-zero rho: use a unit range so edges strictly increase.
            hi = 1.0
        bins = np.linspace(0.0, hi, bins + 1)

    plt.figure(figsize=(9,5))
    for name, x in transformed.items():
        plt.hist(x, bins=bins, density=True, alpha=0.5, label=name)

    plt.xlabel("log(1 + rho),  rho = r/(1-r)")
    plt.ylabel("Density")
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.show()


# ------- run summaries -------
_model_tags = ("Gaussian", "RHS", "DHS", "DST")
models_eigs = dict(zip(_model_tags, (eigs_gauss, eigs_RHS, eigs_DHS, eigs_DST)))
models_df = dict(zip(_model_tags, (df_eff_gauss, df_eff_RHS, df_eff_DHS, df_eff_DST)))

# Text summaries of effective d.o.f. and of the pooled r-eigenvalues per model.
print("df_eff summaries:")
for name, df in models_df.items():
    print(name, summarize_df(df))

print("\nEigenvalue summaries:")
for name, eigs in models_eigs.items():
    print(name, summarize_eigs(eigs, small=0.1, large=0.9))

# Plots to visualize differences in shrinkage
plot_hist_eigs(models_eigs, title="Histogram of shrinkage eigenvalues r")
plot_hist_log1p_rho(models_eigs, title="Histogram of log(1+rho)")


## Forskjell mellom lambda_eff og lambda:

In [70]:
# Draw index 1 of the prior-only RHS fit at N = 100: raw local scales,
# regularised local scales (flattened in storage order) and the global scale.
_rhs_prior_post = prior_N100_fits['Regularized Horseshoe']['posterior']
lambda_samples = _rhs_prior_post.stan_variable("lambda")[1].flatten()
reg_lambda_samples = _rhs_prior_post.stan_variable("lambda_tilde")[1].flatten()
tau_samples = _rhs_prior_post.stan_variable("tau")[1]

In [None]:
import numpy as np

# Single RHS posterior draw (index 1) of the prior and data precisions.
P = P_RHS[1]
S = S_RHS[1]

# P is diagonal and positive, S symmetric. Local names avoid overwriting
# the notebook-level `p` (inputs per hidden unit) and `W` (weight draws).
p_diag = np.diag(P)                            # diagonal of P
P_inv_sqrt = np.diag(1.0 / np.sqrt(p_diag))    # P^{-1/2}
W_white = P_inv_sqrt @ S @ P_inv_sqrt          # prior-whitened data precision

# Symmetric EVD: eigenvalues ascending, columns of U are the eigenvectors.
w_eigs, U = np.linalg.eigh(W_white)

# NOTE(review): lambda_samples is one draw of the raw `lambda` from the PRIOR
# RHS fit, flattened in Stan's storage order. U's rows follow vec(W) in (H, p)
# order from the posterior fit, while W_1 itself is stored as (p, H). Confirm
# both the draw pairing and the flatten order before reading lambda_eff.
inv_lambda2 = 1.0 / (lambda_samples**2)
# Effective local scale along eigenvector i (column U[:, i]):
#   lambda_eff_i^2 = 1 / sum_j U_{ji}^2 / lambda_j^2,
# with j running over weight coordinates -- hence the transpose of U**2.
lambda_eff_sq = 1.0 / ((U**2).T @ inv_lambda2)
lambda_eff = np.sqrt(lambda_eff_sq)  # one value per eigen-direction

lambda_eff.shape, lambda_eff[:5]


In [None]:
median_eff, median_reg = np.median(lambda_eff), np.median(reg_lambda_samples)
print(median_eff, median_reg)

In [73]:
import numpy as np

lam_eff = lambda_eff  # effective scales from the lambda_eff cell
# The median of Half-Cauchy(0, s) is s * tan(pi/4) = s, so the sample median estimates s.
s_hat = np.median(lam_eff)

# QQ points against Half-Cauchy(0, s_hat), whose quantile function is
# F^{-1}(q) = s * tan(pi * q / 2). The plotting positions get their own name
# so the notebook-level `p` is not overwritten.
qq_probs = (np.arange(1, len(lam_eff) + 1) - 0.5) / len(lam_eff)
q_theory = s_hat * np.tan(0.5 * np.pi * qq_probs)
q_emp = np.sort(lam_eff)

# q_emp vs q_theory lies near y = x if lam_eff ~ Half-Cauchy(s_hat)


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import halfcauchy

# --- QQ plot function ---
def qq_plot(data, label, color):
    """Scatter sorted `data` against Half-Cauchy(0, 1) quantiles at positions (i - 0.5) / n."""
    size = len(data)
    positions = (np.arange(size) + 0.5) / size
    theoretical = halfcauchy.ppf(positions, scale=1)
    plt.scatter(theoretical, np.sort(data), label=label, alpha=0.7, color=color)

# --- Tail plot function ---
def tail_plot(data, label, color):
    """Plot x * P_hat(X >= x) at every sorted sample point x (empirical tail weight)."""
    xs = np.sort(data)
    count = len(data)
    survival = (count - np.arange(count)) / count   # n/n, (n-1)/n, ..., 1/n
    plt.plot(xs, xs * survival, label=label, color=color)

# -----------------------------
# QQ plot: effective scales and lambda_tilde draws vs Half-Cauchy(0, 1)
_scale_series = (
    (lambda_eff, "lambda_eff", "C0"),
    (reg_lambda_samples, "lambda_draws", "C1"),
)

plt.figure(figsize=(6,6))
for values, series_label, colour in _scale_series:
    qq_plot(values, series_label, colour)
lims = [0, max(np.max(lambda_eff), np.max(reg_lambda_samples), 10)]
plt.plot(lims, lims, 'k--', lw=1, label="y=x")
plt.xlabel("Theoretical Half-Cauchy(0,1) quantiles")
plt.ylabel("Empirical quantiles")
plt.title("QQ-plot vs Half-Cauchy(0,1)")
plt.legend()
plt.grid(True)
plt.show()

# -----------------------------
# Tail plot: x * P(X > x), empirical vs the Half-Cauchy(0, 1) reference curve
x = np.linspace(0.1, 10, 200)
surv_theory = 1 - halfcauchy.cdf(x, scale=1)

plt.figure(figsize=(7,5))
for values, series_label, colour in _scale_series:
    tail_plot(values, series_label, colour)
plt.plot(x, x * surv_theory, 'k--', lw=1, label="Half-Cauchy(0,1)")
plt.xlabel("x")
plt.ylabel("x * Survival(x)")
plt.title("Tail diagnostic: x * P(X>x)")
plt.legend()
plt.grid(True)
plt.show()


## Build linearized $\bar{w}$

In [32]:
def compute_linearized_mean(
    X, y,
    W_all, b1_all, b2_all, v_all,          # (D,H,p), (D,H), (D,), (D,H)
    sigma_all, tau_w_all, tau_v_all,       # (D,), (D,), (D,)
    lambda_all,                            # (D,H,p)
    activation="tanh",
    return_mats=True,
):
    """
    Per-draw linearized mean of the flattened first-layer weights.

    For each draw d, ``compute_shrinkage_for_W_block`` (defined elsewhere in
    this notebook) is evaluated at that draw's parameters. It must return
    ``(R, P, S, Sigma_y, J, H)``; the last element is ignored here. Then:
      - y*_d    = (y - b2_d * 1) + J_d @ vec(W0_d)
      - g_d     = J_d^T (Sigma_y_d^{-1} y*_d)   [linear solve, no inverse]
      - bar_w_d = (P_d + S_d)^{-1} g_d         [linear solve, no inverse]
    Here vec(W0_d) = W_all[d].reshape(-1), the row-major flattening of the
    (H, p) weight matrix (entry (h, j) sits at index h * p + j).

    Parameters
    ----------
    X : passed unchanged to ``compute_shrinkage_for_W_block``.
    y : array-like with n entries along its first axis (list, (n,), (n, 1));
        converted to a float vector of length n.
    W_all, b1_all, b2_all, v_all, sigma_all, tau_w_all, tau_v_all, lambda_all :
        per-draw parameters with the shapes annotated in the signature.
    activation : forwarded to ``compute_shrinkage_for_W_block``.
    return_mats : if False, R matrices are not stored and ``R_stack`` is None.

    Returns
    -------
    R_stack     : (D, N, N) with N = H * p, or None if ``return_mats`` is False
    w_bar_stack : (D, N)
    """
    import numpy as np

    W_all = np.asarray(W_all)
    D, H, p = W_all.shape
    N = H * p

    # Convert y *before* reading its length so plain lists work. Reshaping to
    # exactly n entries accepts an (n, 1) column, while an (n, k>1) array still
    # raises instead of being silently flattened.
    y = np.asarray(y, dtype=float)
    n = y.shape[0]
    y = y.reshape(n)

    R_stack = np.empty((D, N, N)) if return_mats else None
    w_bar_stack = np.empty((D, N))

    for d in range(D):
        R, P, S, Sigma_y, J, _ = compute_shrinkage_for_W_block(
            X=X,
            W0=W_all[d],
            b0=b1_all[d],
            v0=v_all[d],
            sigma=float(sigma_all[d]),
            tau_w=float(tau_w_all[d]),
            tau_v=float(tau_v_all[d]),
            lambda_tilde=lambda_all[d],
            activation=activation,
        )

        # --- Build y* = (y - c*1) + J @ vec(W0) ---
        c_d = float(b2_all[d])                 # output bias of this draw (scalar)
        z = y - c_d
        w0_vec = W_all[d].reshape(-1)          # (N,), row-major vec(W0)
        # NOTE(review): a first-order expansion f(w) ~ f(w0) + J (w - w0) would
        # give y* = y - c - f(w0) + J w0. No f(w0) term is subtracted here.
        # Confirm that compute_shrinkage_for_W_block's J accounts for this.
        y_star = z + J @ w0_vec                # (n,)

        # --- g = J^T Sigma_y^{-1} y*, via a solve rather than an inverse ---
        r = np.linalg.solve(Sigma_y, y_star)   # (n,)
        g = J.T @ r                            # (N,)

        # --- bar_w = (P + S)^{-1} g, via a solve rather than an inverse ---
        bar_w = np.linalg.solve(P + S, g)      # (N,)

        if return_mats:
            R_stack[d] = R
        w_bar_stack[d] = bar_w

    return R_stack, w_bar_stack

# Linearized means bar_w for each posterior model of the N=100 fits.
# NOTE(review): the first notebook cell loads X, y from the N500 Friedman
# dataset, while these draws come from posterior_N100_fits. Confirm that X, y
# were reloaded for the N=100 configuration before this cell runs.
# The draws are unpacked as (W, b1, v, b2, ...) but compute_linearized_mean
# takes (W, b1, b2, v, ...); the calls below pass them in the latter order.
# The shared draw names are overwritten per model, so after this cell they
# hold the 'Dirichlet Student T tanh' draws.
W, b1, v, b2, sigma, tau_w, tau_v, lambda_tilde = extract_model_draws(posterior_N100_fits, model='Gaussian tanh')

R_stack_gauss, w_bar_stack_gauss = compute_linearized_mean(X, y, W, b1, b2, v, sigma, tau_w, tau_v, lambda_tilde, activation="tanh")

W, b1, v, b2, sigma, tau_w, tau_v, lambda_tilde = extract_model_draws(posterior_N100_fits, model='Regularized Horseshoe tanh')

R_stack_RHS, w_bar_stack_RHS = compute_linearized_mean(X, y, W, b1, b2, v, sigma, tau_w, tau_v, lambda_tilde, activation="tanh")

W, b1, v, b2, sigma, tau_w, tau_v, lambda_tilde = extract_model_draws(posterior_N100_fits, model='Dirichlet Horseshoe tanh')

R_stack_DHS, w_bar_stack_DHS = compute_linearized_mean(X, y, W, b1, b2, v, sigma, tau_w, tau_v, lambda_tilde, activation="tanh")

W, b1, v, b2, sigma, tau_w, tau_v, lambda_tilde = extract_model_draws(posterior_N100_fits, model='Dirichlet Student T tanh')

R_stack_DST, w_bar_stack_DST = compute_linearized_mean(X, y, W, b1, b2, v, sigma, tau_w, tau_v, lambda_tilde, activation="tanh")



In [33]:
# Raw first-layer ("W_1") and output-layer ("W_L") weight draws per model,
# read straight from the N=100 posterior fits for the alignment comparison.
W_all_gauss, v_all_gauss = (
    posterior_N100_fits['Gaussian tanh']['posterior'].stan_variable(var_name)
    for var_name in ("W_1", "W_L")
)

W_all_RHS, v_all_RHS = (
    posterior_N100_fits['Regularized Horseshoe tanh']['posterior'].stan_variable(var_name)
    for var_name in ("W_1", "W_L")
)

W_all_DHS, v_all_DHS = (
    posterior_N100_fits['Dirichlet Horseshoe tanh']['posterior'].stan_variable(var_name)
    for var_name in ("W_1", "W_L")
)

W_all_DST, v_all_DST = (
    posterior_N100_fits['Dirichlet Student T tanh']['posterior'].stan_variable(var_name)
    for var_name in ("W_1", "W_L")
)

In [None]:
def align_and_compare(W_all, v_all, w_bar_stack, sort_key="abs_v"):
    """
    Align hidden units across draws (sign + permutation), then compare the
    linearized mean with the posterior mean in the aligned basis.

    Each draw d goes through two steps:
      1. Sign fix: every unit h with v[h] < 0 has W[h] and w_bar[h] negated,
         and v[h] is replaced by |v[h]|. Units with v[h] == 0 are untouched.
      2. Permutation: units are reordered by ``sort_key``.
           "abs_v"               -> descending |v|
           "abs_v_times_rownorm" -> descending |v| * ||W[h]||_2
    The aligned arrays are then averaged over draws and the two means compared.

    Parameters
    ----------
    W_all       : array-like, (D, H, p) or (D, p, H), optionally with extra
                  singleton dims.
    v_all       : array-like with H values per draw, e.g. (D, H) or (D, H, 1).
    w_bar_stack : array-like with H*p values per draw, e.g. (D, H*p) or
                  (D, H, p). A flat row is read as row-major (H, p).
    sort_key    : "abs_v" or "abs_v_times_rownorm".

    Returns
    -------
    W_fix    : (D, H, p)  sign/permutation aligned
    v_fix    : (D, H)
    wbar_fix : (D, H, p)
    summary  : dict with RMSE, Corr, CosSim, SignAgree (linearized vs posterior
               means in the aligned basis), plus H, p, N.

    Raises
    ------
    ValueError
        Raised for an unknown ``sort_key``, no draws, mismatched draw counts,
        or shapes that cannot be coerced to (H, p).
    """
    import numpy as np

    if sort_key not in ("abs_v", "abs_v_times_rownorm"):
        raise ValueError(f"Unknown sort_key: {sort_key}")

    W_all = np.asarray(W_all)
    v_all = np.asarray(v_all)
    w_bar_stack = np.asarray(w_bar_stack)

    D = W_all.shape[0]
    if D == 0:
        raise ValueError("W_all contains no draws.")
    if v_all.shape[0] != D or w_bar_stack.shape[0] != D:
        raise ValueError(
            f"Draw counts differ: W_all has {D}, v_all has {v_all.shape[0]}, "
            f"w_bar_stack has {w_bar_stack.shape[0]}."
        )

    # --- H comes from v (source of truth); p from the length of a w_bar row ---
    H = np.asarray(v_all[0]).size
    if H == 0:
        raise ValueError("v_all[0] seems empty; cannot infer H.")
    n_wbar = np.asarray(w_bar_stack[0]).size
    if n_wbar == 0 or n_wbar % H != 0:
        raise ValueError(
            f"Cannot infer (H,p): v length={H}, but w_bar_stack[0] has {n_wbar} elems."
        )
    p = n_wbar // H
    N = H * p

    def _coerce_W(Wd):
        """Return one draw of W as (H, p); a (p, H) layout is transposed."""
        A = np.asarray(Wd, dtype=float)
        if A.size == H * p:
            if H == 1 or p == 1:
                # With a trivial axis, (H, p) and (p, H) share one flat order.
                # Squeezing would drop the axis we need, so reshape directly.
                return A.reshape(H, p)
            A = np.squeeze(A)
            if A.shape == (H, p):
                return A
            if A.shape == (p, H):
                return A.T
        raise ValueError(f"Cannot coerce W of shape {A.shape} to (H,p)=({H},{p}).")

    def _coerce_v(vd):
        """Return one draw of v as a flat (H,) vector."""
        v = np.asarray(vd, dtype=float).ravel()
        if v.size != H:
            raise ValueError(f"v has size {v.size}, expected H={H}.")
        return v

    def _coerce_wbar_row(wbd):
        """Return one w_bar row as (H, p), reading flat input row-major."""
        w = np.asarray(wbd, dtype=float).ravel()
        if w.size != H * p:
            raise ValueError(f"w_bar row has {w.size} elems but H*p={H * p}.")
        return w.reshape(H, p)

    W_fix = np.empty((D, H, p), dtype=float)
    v_fix = np.empty((D, H), dtype=float)
    wbar_fix = np.empty((D, H, p), dtype=float)

    for d in range(D):
        Wd = _coerce_W(W_all[d])
        vd = _coerce_v(v_all[d])
        wbd = _coerce_wbar_row(w_bar_stack[d])

        # 1) Sign fix so v >= 0. For an odd activation (e.g. tanh),
        #    v * act(w.x + b) == (-v) * act(-w.x - b), so a unit's sign is not
        #    identified. (The bias is not part of this comparison.)
        s = np.sign(vd)
        s[s == 0.0] = 1.0
        Wd = Wd * s[:, None]
        wbd = wbd * s[:, None]
        vd = np.abs(vd)

        # 2) Canonical unit order, so the same unit lands in the same slot
        #    across draws.
        if sort_key == "abs_v":
            score = vd
        else:  # "abs_v_times_rownorm"
            score = vd * np.linalg.norm(Wd, axis=1)
        idx = np.argsort(-score)  # descending

        W_fix[d] = Wd[idx]
        wbar_fix[d] = wbd[idx]
        v_fix[d] = vd[idx]

    # Compare means in the aligned basis.
    w_post_mean = W_fix.reshape(D, -1).mean(axis=0)   # (N,)
    w_lin_mean = wbar_fix.reshape(D, -1).mean(axis=0)

    rmse = float(np.sqrt(np.mean((w_lin_mean - w_post_mean) ** 2)))
    corr = float(np.corrcoef(w_lin_mean, w_post_mean)[0, 1])
    denom = np.linalg.norm(w_lin_mean) * np.linalg.norm(w_post_mean)
    # Cosine similarity is undefined if either mean is the zero vector.
    cos = float(np.dot(w_lin_mean, w_post_mean) / denom) if denom > 0 else float("nan")
    sign_agree = float(np.mean(np.sign(w_lin_mean) == np.sign(w_post_mean)))

    summary = dict(RMSE=rmse, Corr=corr, CosSim=cos, SignAgree=sign_agree,
                   H=H, p=p, N=N)
    return W_fix, v_fix, wbar_fix, summary

# Input-column split for the block comparison: columns 0-4 are treated as
# active and 5-9 as inactive (adjust if your actives are different).
active_cols = np.arange(5)
inactive_cols = np.arange(5, 10)


def block_stats(A, B, cols):
    """Compare two 2-D weight matrices on a subset of their columns.

    Both ``A[:, cols]`` and ``B[:, cols]`` are flattened. The function returns
    the Pearson correlation and the root-mean-square difference between them,
    as ``(corr, rmse)``.
    """
    lhs = A[:, cols].reshape(-1)
    rhs = B[:, cols].reshape(-1)
    rmse = np.sqrt(np.mean(np.square(lhs - rhs)))
    return np.corrcoef(lhs, rhs)[0, 1], rmse

# Per-model report, printed in the same order for every model:
#   1. align_and_compare summary
#   2. nRMSE / R^2 of the linearized mean vs the posterior mean
#   3. correlation / RMSE restricted to the active and inactive input columns
# Previously the block stats were printed before each model's label, so they
# appeared under the preceding model; Gaussian had none and DST's appeared twice.
# The loop leaves W_fix, v_fix, wbar_fix, summary, w_post_mean, w_lin_mean,
# nrmse, r2, H, p, lin and post bound to the last model (DST), as before.
for model_label, W_all_m, v_all_m, w_bar_stack_m in [
    ("Gaussian", W_all_gauss, v_all_gauss, w_bar_stack_gauss),
    ("RHS", W_all_RHS, v_all_RHS, w_bar_stack_RHS),
    ("DHS", W_all_DHS, v_all_DHS, w_bar_stack_DHS),
    ("DST", W_all_DST, v_all_DST, w_bar_stack_DST),
]:
    W_fix, v_fix, wbar_fix, summary = align_and_compare(
        W_all_m, v_all_m, w_bar_stack_m, sort_key="abs_v"
    )
    print(f"{model_label}: \n", summary)

    # Normalized RMSE and R^2 of the linearized mean, with the aligned
    # posterior mean as the reference.
    w_post_mean = W_fix.reshape(W_fix.shape[0], -1).mean(axis=0)
    w_lin_mean = wbar_fix.reshape(wbar_fix.shape[0], -1).mean(axis=0)
    nrmse = np.linalg.norm(w_lin_mean - w_post_mean) / np.linalg.norm(w_post_mean)
    r2 = 1 - np.sum((w_lin_mean - w_post_mean) ** 2) / np.sum((w_post_mean - w_post_mean.mean()) ** 2)
    print(f"nRMSE: {nrmse:.3f}, R^2: {r2:.3f} \n")

    # Same comparison restricted to the active / inactive input columns.
    H, p = summary["H"], summary["p"]
    lin = wbar_fix.mean(axis=0).reshape(H, p)
    post = W_fix.mean(axis=0).reshape(H, p)
    print("Active  -> Corr, RMSE:", block_stats(lin, post, active_cols))
    print("Inactive-> Corr, RMSE:", block_stats(lin, post, inactive_cols), "\n")

