In [None]:
import sys, os; sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__) if '__file__' in globals() else os.getcwd(), '..')))
#import os; os.chdir(os.path.dirname(os.getcwd()))
from utils.model_loader import get_model_fits
import numpy as np
import pandas as pd
import re
#from sklearn.metrics import mean_squared_errosr
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
results_dir_priors = "results/priors/single_layer/tanh/friedman"
results_dir_posteriors = "results/regression/single_layer/tanh/friedman"

prior_names = ["Dirichlet Horseshoe", "Regularized Horseshoe", "Dirichlet Student T", "Gaussian"]
posterior_names = ["Dirichlet Horseshoe tanh", "Regularized Horseshoe tanh", "Dirichlet Student T tanh", "Gaussian tanh"]

# One Friedman configuration per sample size (N = 100, 200, 500); each N has its own seed.
_friedman_configs = (
    "Friedman_N100_p10_sigma1.00_seed1",
    "Friedman_N200_p10_sigma1.00_seed2",
    "Friedman_N500_p10_sigma1.00_seed11",
)


def _load_fits(results_dir, models):
    """Load the fits of `models` from `results_dir` for every config, in config order."""
    return [
        get_model_fits(
            config=config,
            results_dir=results_dir,
            models=models,
            include_prior=False,
        )
        for config in _friedman_configs
    ]


# Prior-only fits first, then posterior fits (same loading order as before).
prior_N100_fits, prior_N200_fits, prior_N500_fits = _load_fits(results_dir_priors, prior_names)
posterior_N100_fits, posterior_N200_fits, posterior_N500_fits = _load_fits(results_dir_posteriors, posterior_names)


In [3]:
# Training inputs used for the Jacobian / Sigma_y computations in the shrinkage analysis.
# NOTE(review): this is the N=500 (seed 11) dataset, while the shrinkage run below
# uses posterior_N100_fits (fit on Friedman_N100_p10_sigma1.00_seed1). S = J^T Sigma_y^{-1} J
# grows with the number of rows, so confirm this mix is intended; otherwise load the N100 file.
path = "datasets/friedman/Friedman_N500_p10_sigma1.00_seed11.npz"
# np.load on an .npz returns an NpzFile holding the file open; the context manager
# closes it. Indexing reads the array into memory, so X remains valid afterwards.
with np.load(path) as data:
    X = data["X_train"]

In [32]:
import numpy as np
from typing import Tuple, Callable

# ---------- Aktivasjon og deriverte ----------

def get_activation(activation: str = "tanh") -> Tuple[Callable, Callable]:
    """
    Look up an activation function and its derivative.

    Args:
      activation: "tanh" or "relu".

    Returns:
      (phi, dphi) where phi(a) is the activation and dphi(a) its elementwise
      derivative. For "tanh", phi is np.tanh itself.

    Raises:
      ValueError: for any other activation name.
    """
    def _tanh_grad(a):
        return 1.0 - np.tanh(a)**2

    def _relu(a):
        return np.maximum(0.0, a)

    def _relu_grad(a):
        # Subgradient 0 at a == 0; cast the boolean mask back to the input dtype.
        return (a > 0.0).astype(a.dtype)

    registry = {
        "tanh": (np.tanh, _tanh_grad),
        "relu": (_relu, _relu_grad),
    }
    if activation not in registry:
        raise ValueError(f"Unsupported activation: {activation}")
    return registry[activation]

# ---------- H(w0) og J_W(w0, v0) ----------

def build_hidden_and_jacobian_W(
    X: np.ndarray,               # (n, p)
    W0: np.ndarray,              # (H, p)  -- input-layer weights at the reference point w0
    b0: np.ndarray,              # (H,)    -- hidden biases at the reference point w0
    v0: np.ndarray,              # (H,)    -- output weights at the reference point v0
    activation: str = "tanh",
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Hidden-layer activations and the Jacobian of the network output w.r.t. W.

    Returns:
      Hmat : (n, H)   = H(w0) = phi(X W0^T + b0)
      JW   : (n, H*p) = d(H(w) v)/d vec(W) at (w0, v0). Column h*p + j holds
                        df_i/dW_{h,j} = v0[h] * phi'(a_{i,h}) * X[i, j],
                        i.e. columns are ordered h-major (h=0..H-1, then j=0..p-1).

    Raises:
      ValueError: if W0 does not have the same number of columns as X.
    """
    n, p = X.shape
    H, pW = W0.shape
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if pW != p:
        raise ValueError(f"W0 has {pW} columns but X has {p} features")
    phi, dphi = get_activation(activation)

    # Pre-activations a_{i,h} and post-activations h_{i,h}.
    A = X @ W0.T + b0[None, :]        # (n, H)
    Hmat = phi(A)                     # (n, H)

    # g_{i,h} = v0[h] * phi'(a_{i,h}); JW[i, h, j] = g_{i,h} * X[i, j].
    G = v0[None, :] * dphi(A)         # (n, H)
    # Flattening the trailing (H, p) axes in C order gives column index h*p + j.
    JW = (G[:, :, None] * X[:, None, :]).reshape(n, H * p)
    return Hmat, JW

# ---------- Sigma_y og P ----------

def build_Sigma_y(
    Hmat: np.ndarray,     # (n, H) = H(w0)
    tau_v: float,         # prior std for v
    sigma: float          # noise std in the likelihood
) -> np.ndarray:
    """
    Sigma_y = tau_v^2 * H H^T + sigma^2 * I_n, an (n, n) matrix.

    This has the form of the covariance of y = H v + noise when
    v ~ N(0, tau_v^2 I) and the noise is N(0, sigma^2 I).
    """
    n_obs = Hmat.shape[0]
    signal = (tau_v**2) * (Hmat @ Hmat.T)
    noise = (sigma**2) * np.eye(n_obs)
    return signal + noise

def build_P_from_lambda_tau(
    lambda_tilde: np.ndarray,  # (H, p) local VARIANCE factors for W
    tau_w: float               # global scale (std) for w
) -> np.ndarray:
    """
    Diagonal prior precision of vec(W): P = diag(1 / (tau_w^2 * lambda_tilde)).

    `lambda_tilde` is used as a variance factor (it is NOT squared here): the
    prior variance of W_{h,j} is tau_w^2 * lambda_tilde[h, j], so its precision
    is the reciprocal of that. Entries are flattened in row-major (h, j) order,
    matching the column ordering of the Jacobian J_W.

    Returns:
      (H*p, H*p) diagonal matrix. Zero entries in lambda_tilde yield inf.
    """
    lam_vec = np.asarray(lambda_tilde).reshape(-1)   # (H*p,)
    diagP = 1.0 / ((tau_w**2) * lam_vec)             # (H*p,)
    return np.diag(diagP)

# ---------- S, shrinkage-matrise R = (P+S)^{-1} P ----------

def build_S(JW: np.ndarray, Sigma_y: np.ndarray) -> np.ndarray:
    """
    Data precision for vec(W): S = J_W^T Sigma_y^{-1} J_W, shape (Hp, Hp).

    Sigma_y^{-1} J_W is obtained with a linear solve rather than an explicit
    inverse, which is the numerically preferable route.
    """
    sigma_inv_jac = np.linalg.solve(Sigma_y, JW)   # (n, Hp)
    return JW.T @ sigma_inv_jac

def shrinkage_matrix(P: np.ndarray, S: np.ndarray) -> np.ndarray:
    """
    Shrinkage matrix R = (P + S)^{-1} P, i.e. the solution of (P + S) R = P.

    First tries a Cholesky factorization P + S = L L^T followed by two solves
    (L Y = P, then L^T R = Y). If the factorization or either solve raises
    np.linalg.LinAlgError (e.g. P + S is not positive definite), falls back to a
    general np.linalg.solve on P + S.
    """
    lhs = P + S
    try:
        chol = np.linalg.cholesky(lhs)
        return np.linalg.solve(chol.T, np.linalg.solve(chol, P))
    except np.linalg.LinAlgError:
        return np.linalg.solve(lhs, P)

def shrinkage_matrix_stable(P, S, jitter=0.0):
    """
    Numerically stable R = (P + S)^{-1} P for a diagonal, positive P.

    With s = sqrt(diag(P)) and M = P^{-1/2} S P^{-1/2}:
        P + S = P^{1/2} (I + M) P^{1/2}
        R     = (P + S)^{-1} P = P^{-1/2} (I + M)^{-1} P^{1/2}.
    For PSD S, every eigenvalue of I + M is >= 1, so (I + M) is safe to
    factorize even when diag(P) spans many orders of magnitude.

    Note: R is in general NOT symmetric; it is similar to the symmetric
    (I + M)^{-1}, so they share eigenvalues.

    Args:
      P      : (N, N); only its diagonal is read.
      S      : (N, N) symmetric positive semi-definite data precision.
      jitter : optional value added to the diagonal of M before factorizing.

    Returns:
      R : (N, N).

    Raises:
      np.linalg.LinAlgError if I + M (+ jitter) is not positive definite.
    """
    d = np.diag(P).astype(float)
    # Guard against zero/negative precisions (NaN passes through np.clip unchanged).
    eps = 1e-12
    d = np.clip(d, eps, np.finfo(float).max)
    s = np.sqrt(d)

    # M = P^{-1/2} S P^{-1/2}, computed elementwise instead of via dense diagonal matmuls.
    M = S / np.outer(s, s)
    if jitter > 0:
        M = M + jitter * np.eye(M.shape[0])

    eye = np.eye(M.shape[0])
    L = np.linalg.cholesky(eye + M)
    # K = (I + M)^{-1} = L^{-T} L^{-1}
    K = np.linalg.solve(L.T, np.linalg.solve(L, eye))
    # R = P^{-1/2} K P^{1/2}: scale row i by 1/s_i and column j by s_j.
    return K * (s[None, :] / s[:, None])

def shrinkage_eigs_and_df(P, S):
    """
    Eigenvalues r of R = (P + S)^{-1} P and the effective degrees of freedom.

    R is similar to (I + M)^{-1} with M = P^{-1/2} S P^{-1/2} (only diag(P) is
    read), so r = 1 / (1 + mu) for the eigenvalues mu of the symmetric M.

    Returns:
      r      : (N,) values in (0, 1], ordered by ascending mu (i.e. descending r).
      df_eff : sum(1 - r) = sum(mu / (1 + mu)) >= 0.
    """
    d = np.diag(P).astype(float)
    eps = 1e-12
    s = np.sqrt(np.maximum(d, eps))

    M = S / np.outer(s, s)
    # eigvalsh reads only one triangle; symmetrize so round-off asymmetry is averaged, not ignored.
    M = 0.5 * (M + M.T)
    mu = np.linalg.eigvalsh(M)
    # S is PSD in exact arithmetic; clip tiny negative round-off so r stays in (0, 1].
    mu = np.maximum(mu, 0.0)
    r = 1.0 / (1.0 + mu)
    df_eff = np.sum(1.0 - r)
    return r, df_eff


# ---------- (Valgfritt) shrinket "hat w" gitt y*, hvis du vil sjekke tallene ----------

# def extract_model_draws(fit_dict, model: str):
#     """
#     Returns ALL draws stacked:
#       W_all      : (D, H, p)
#       b_all      : (D, H)
#       v_all      : (D, H)
#       c_all      : (D,)
#       sigma_all  : (D,)
#       tau_w_all  : (D,)
#       tau_v_all  : (D,)   (ones if not in fit)
#       lambda_all : (D, H, p)
#     """
#     post = fit_dict[model]['posterior']

#     # W_1: (D, p, H) -> (D, H, p)
#     W_1 = np.asarray(post.stan_variable("W_1"))
#     W_all = np.transpose(W_1, (0, 2, 1))              # (D, H, p)
#     D, H, p = W_all.shape

#     # W_L: (D, H, 1) -> (D, H)
#     W_L = np.asarray(post.stan_variable("W_L"))
#     v_all = W_L.reshape(D, -1)                        # (D, H)

#     # hidden_bias: (D, 1, H) -> (D, H)
#     b_1 = np.asarray(post.stan_variable("hidden_bias"))
#     b_all = b_1.reshape(D, -1)                        # (D, H)

#     # output_bias: (D, 1) -> (D,)
#     b_2 = np.asarray(post.stan_variable("output_bias"))
#     c_all = b_2.reshape(D)                            # (D,)

#     # sigma: (D,)
#     sigma_all = np.asarray(post.stan_variable("sigma")).reshape(D)

#     # tau_w: (D,)
#     if model in ("Gaussian", "Gaussian tanh"):
#         tau_w_all = np.ones(D)
#         tau_v_all = np.ones(D)
#         lambda_all = np.ones((D, H, p))
#     else:
#         tau_w_all = np.asarray(post.stan_variable("tau")).reshape(D)
#         # tau_v may not exist; default to ones
#         try:
#             tau_v_all = np.asarray(post.stan_variable("tau_v")).reshape(D)
#         except Exception:
#             tau_v_all = np.ones(D)

#         # lambda_tilde: (D, H, p)  or  (D, p, H)
#         lam_name = "lambda_tilde" if model in ("Regularized Horseshoe","Regularized Horseshoe tanh") else "lambda_tilde_data"
#         lam = np.asarray(post.stan_variable(lam_name))
#         if lam.shape[1:] == (H, p):
#             lambda_all = lam
#         else:
#             # assume (D, p, H) -> (D, H, p)
#             lambda_all = np.transpose(lam, (0, 2, 1))

#     return W_all, b_all, v_all, c_all, sigma_all, tau_w_all, tau_v_all, lambda_all

def _as_draws_by_hp(arr: np.ndarray, H: int, p: int, name: str) -> np.ndarray:
    """Return per-weight draws as (D, H, p), accepting storage as (D, H, p) or (D, p, H)."""
    # NOTE(review): when H == p the two layouts are indistinguishable by shape and the
    # array is taken as already (D, H, p) -- confirm the Stan layout if H == p is ever used.
    if arr.shape[1:] == (H, p):
        return arr
    if arr.shape[1:] == (p, H):
        return np.transpose(arr, (0, 2, 1))
    raise ValueError(
        f"{name}: expected trailing shape {(H, p)} or {(p, H)}, got {arr.shape[1:]}"
    )


def extract_model_draws(fit_dict, model: str):
    """
    Return ALL posterior draws of one model, with `lambda_all` defined as the
    EFFECTIVE local variance factor per input-layer weight:
      Gaussian:                 lambda_all = 1
      Regularized Horseshoe:    lambda_all = lambda_tilde
      Dirichlet (DHS/DST):      lambda_all = lambda_tilde_data * phi_data

    Args:
      fit_dict : mapping model name -> dict with a 'posterior' object exposing
                 stan_variable(name).
      model    : key into fit_dict; its substrings ("Gaussian", "Regularized
                 Horseshoe", "Dirichlet"/"DST") select which variables are read.

    Returns (shapes):
      W_all      : (D, H, p)
      b_all      : (D, H)
      v_all      : (D, H)
      c_all      : (D,)
      sigma_all  : (D,)
      tau_w_all  : (D,)   (ones for Gaussian)
      tau_v_all  : (D,)   (ones for Gaussian or if 'tau_v' cannot be read)
      lambda_all : (D, H, p)  effective variance factor

    Raises:
      ValueError if a per-weight variable is neither (D, H, p) nor (D, p, H).
    """
    post = fit_dict[model]['posterior']

    # W_1: (D, p, H) -> (D, H, p)
    W_1 = np.asarray(post.stan_variable("W_1"))
    W_all = np.transpose(W_1, (0, 2, 1))
    D, H, p = W_all.shape

    # W_L: (D, H, out_nodes=1) -> (D, H)
    v_all = np.asarray(post.stan_variable("W_L")).reshape(D, -1)

    # hidden_bias: (D, 1, H) -> (D, H)
    b_all = np.asarray(post.stan_variable("hidden_bias")).reshape(D, -1)

    # output_bias: (D, 1) -> (D,)
    c_all = np.asarray(post.stan_variable("output_bias")).reshape(D)

    sigma_all = np.asarray(post.stan_variable("sigma")).reshape(D)

    # Model flags (substring match on the model name)
    is_gauss = ("Gaussian" in model)
    is_rhs = ("Regularized Horseshoe" in model)
    is_dirichlet = ("Dirichlet" in model) or ("DST" in model)

    if is_gauss:
        tau_w_all = np.ones(D)
        tau_v_all = np.ones(D)
        lambda_all = np.ones((D, H, p))
        return W_all, b_all, v_all, c_all, sigma_all, tau_w_all, tau_v_all, lambda_all

    tau_w_all = np.asarray(post.stan_variable("tau")).reshape(D)
    # Best-effort: not every model defines tau_v. Only the lookup is guarded, so a
    # shape error in the reshape below is reported instead of silently becoming ones.
    try:
        tau_v_raw = post.stan_variable("tau_v")
    except Exception:
        tau_v_all = np.ones(D)
    else:
        tau_v_all = np.asarray(tau_v_raw).reshape(D)

    lam_name = "lambda_tilde" if is_rhs else "lambda_tilde_data"
    lam_var = _as_draws_by_hp(np.asarray(post.stan_variable(lam_name)), H, p, lam_name)

    if is_dirichlet:
        phi_hp = _as_draws_by_hp(np.asarray(post.stan_variable("phi_data")), H, p, "phi_data")
        # Per the Stan model: stddev = tau * sqrt(lambda_tilde) * sqrt(phi)
        #  => var = tau^2 * lambda_tilde * phi
        lambda_all = lam_var * phi_hp
    else:
        # RHS: var = tau^2 * lambda_tilde
        lambda_all = lam_var

    return W_all, b_all, v_all, c_all, sigma_all, tau_w_all, tau_v_all, lambda_all

# ------- Knyt alt sammen -------

def compute_shrinkage_for_W_block(
    X: np.ndarray,
    W0: np.ndarray, b0: np.ndarray, v0: np.ndarray,
    sigma: float, tau_w: float, tau_v: float,
    lambda_tilde: np.ndarray,
    activation: str = "tanh"
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Shrinkage quantities for the input-layer weight block W at a single draw.

    Returns:
      (R, P, S, Sigma_y) where
        R       : (Hp, Hp) output of shrinkage_matrix_stable(P, S)
                  (intended to be R = (P + S)^{-1} P),
        P       : (Hp, Hp) diagonal prior precision of vec(W),
        S       : (Hp, Hp) data precision J_W^T Sigma_y^{-1} J_W,
        Sigma_y : (n, n)   tau_v^2 H H^T + sigma^2 I.
    """
    hidden, jac_W = build_hidden_and_jacobian_W(X, W0, b0, v0, activation=activation)
    prior_prec = build_P_from_lambda_tau(lambda_tilde, tau_w=tau_w)
    marg_cov = build_Sigma_y(hidden, tau_v=tau_v, sigma=sigma)
    data_prec = build_S(jac_W, marg_cov)
    shrink = shrinkage_matrix_stable(prior_prec, data_prec)
    return shrink, prior_prec, data_prec, marg_cov


def compute_shrinkage(
    X,
    W_all, b_all, v_all,          # (D,H,p), (D,H), (D,H)
    sigma_all, tau_w_all, tau_v_all,  # (D,), (D,), (D,)
    lambda_all,                   # (D,H,p)
    activation="tanh",
    return_mats=True,             # set False if you only want summaries
):
    """
    Per-draw shrinkage analysis of the input-layer weight block W.

    For every draw d, calls compute_shrinkage_for_W_block at
    (W_all[d], b_all[d], v_all[d], sigma_all[d], tau_w_all[d], tau_v_all[d],
    lambda_all[d]) and summarizes (P, S) with shrinkage_eigs_and_df.

    Returns (R_stack, P_stack, r_eigs, df_eff):
      R_stack : (D, N, N), N = H*p -- per-draw R from compute_shrinkage_for_W_block
                (None if return_mats=False)
      P_stack : (D, N, N)          -- per-draw diagonal prior precision P
                (None if return_mats=False)
      r_eigs  : (D, N)             -- eigenvalues r = 1/(1+mu) per draw, sorted ascending
      df_eff  : (D,)               -- effective degrees of freedom sum(1 - r) per draw

    Memory: each of R_stack and P_stack holds D*N*N float64 values
    (8*D*N^2 bytes); pass return_mats=False when only the summaries are needed.
    """
    D, H, p = W_all.shape
    N = H * p

    R_stack = np.empty((D, N, N)) if return_mats else None
    P_stack = np.empty((D, N, N)) if return_mats else None
    r_eigs = np.empty((D, N))
    df_eff = np.empty(D)

    for d in range(D):
        R, P, S, _ = compute_shrinkage_for_W_block(
            X=X,
            W0=W_all[d],
            b0=b_all[d],
            v0=v_all[d],
            sigma=float(sigma_all[d]),
            tau_w=float(tau_w_all[d]),
            tau_v=float(tau_v_all[d]),
            lambda_tilde=lambda_all[d],
            activation=activation,
        )

        if return_mats:
            R_stack[d] = R
            P_stack[d] = P

        r, df = shrinkage_eigs_and_df(P, S)
        r_eigs[d] = np.sort(r)
        df_eff[d] = df

    return R_stack, P_stack, r_eigs, df_eff


# ---------- Minimal run: shrinkage analysis for the four N=100 posteriors ----------
_shrinkage_runs = {}
for _key, _model in (
    ("gauss", "Gaussian tanh"),
    ("RHS", "Regularized Horseshoe tanh"),
    ("DHS", "Dirichlet Horseshoe tanh"),
    ("DST", "Dirichlet Student T tanh"),
):
    # The unpacked draw arrays stay bound at module level (last model's values afterwards).
    W, b1, v, b2, sigma, tau_w, tau_v, lambda_tilde = extract_model_draws(posterior_N100_fits, model=_model)
    _shrinkage_runs[_key] = compute_shrinkage(
        X, W, b1, v, sigma, tau_w, tau_v, lambda_tilde, activation="tanh"
    )

R_gauss, P_gauss, eigs_gauss, df_eff_gauss = _shrinkage_runs["gauss"]
R_RHS, P_RHS, eigs_RHS, df_eff_RHS = _shrinkage_runs["RHS"]
R_DHS, P_DHS, eigs_DHS, df_eff_DHS = _shrinkage_runs["DHS"]
R_DST, P_DST, eigs_DST, df_eff_DST = _shrinkage_runs["DST"]
# For a single draw d, call compute_shrinkage_for_W_block(X, W[d], b1[d], v[d], ...) directly.


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# ------- summaries -------
def summarize_df(df):
    """
    Summary statistics of effective-degrees-of-freedom draws.

    Returns a dict with the mean, sample standard deviation (ddof=1) and the
    10/50/90% quantiles of the flattened input, all as Python floats.
    """
    values = np.asarray(df).reshape(-1)
    q10, q50, q90 = np.quantile(values, [0.10, 0.50, 0.90])
    return {
        "mean": float(values.mean()),
        "sd": float(values.std(ddof=1)),
        "q10": float(q10),
        "q50": float(q50),
        "q90": float(q90),
    }

def summarize_eigs(eigs, small=0.1, large=0.9):
    """
    Pooled summary of shrinkage eigenvalues r in [0, 1].

    `eigs` is flattened (any shape, e.g. (D, N)). Returns the median and the
    10/90% quantiles of r, plus the fraction of values below `small` and above
    `large`; the last two keys embed the thresholds with two decimals.
    """
    flat = np.asarray(eigs).reshape(-1)
    q10, q90 = np.quantile(flat, [0.10, 0.90])
    summary = {
        "median_r": float(np.median(flat)),
        "q10_r": float(q10),
        "q90_r": float(q90),
    }
    summary["frac_small(<{:.2f})".format(small)] = float(np.mean(flat < small))
    summary["frac_large(>{:.2f})".format(large)] = float(np.mean(flat > large))
    return summary

# ------- Histogram plotting helpers -------

def plot_hist_eigs(model_eigs_dict, bins=50, title=""):
    """
    Overlay density histograms of shrinkage eigenvalues for several models.

    model_eigs_dict: {"Name": eigs}; eigs has shape (D, N) or (D*N,) and holds
                     r-eigenvalues in [0, 1] (values outside are clipped).
    bins:            an int gives that many equal-width bins on [0, 1], shared
                     by every model; anything else is passed to plt.hist as-is.
    """
    plt.figure(figsize=(9,5))

    edges = np.linspace(0.0, 1.0, bins + 1) if isinstance(bins, int) else bins

    for label, values in model_eigs_dict.items():
        flat = np.clip(np.asarray(values).ravel(), 0.0, 1.0)
        plt.hist(flat, bins=edges, density=True, alpha=0.5, label=label)

    plt.xlabel("r eigenvalues (0 = no shrink, 1 = hard shrink)")
    plt.ylabel("Density")
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.show()


# ------- optional: rho transform (prior/data ratio) -------
def plot_hist_log1p_rho(model_eigs_dict, bins=50, eps=1e-12, title=""):
    """
    Overlay density histograms of log(1 + rho), rho = r / (1 - r), per model.

    model_eigs_dict: {"Name": eigs}; eigs has shape (D, N) or (D*N,) and holds
                     r-eigenvalues in [0, 1]. r is clipped to [0, 1 - eps] so
                     r = 1 does not divide by zero.
    bins:            an int gives equal-width bins on [0, q99.5] of the pooled
                     transformed values (the extreme 0.5% tail is cut off);
                     anything else is passed to plt.hist as-is.
    """
    def _log1p_rho(eigs):
        r = np.clip(np.asarray(eigs).ravel(), 0.0, 1.0 - eps)
        return np.log1p(r / (1.0 - r))

    transformed = {label: _log1p_rho(eigs) for label, eigs in model_eigs_dict.items()}
    pooled = np.concatenate(list(transformed.values()))

    if isinstance(bins, int):
        upper = float(np.quantile(pooled, 0.995))
        bins = np.linspace(0.0, upper, bins + 1)

    plt.figure(figsize=(9,5))
    for label, values in transformed.items():
        plt.hist(values, bins=bins, density=True, alpha=0.5, label=label)

    plt.xlabel("log(1 + rho),  rho = r/(1-r)")
    plt.ylabel("Density")
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.show()


# ------- run summaries -------
_labels = ("Gaussian", "RHS", "DHS", "DST")
models_eigs = dict(zip(_labels, (eigs_gauss, eigs_RHS, eigs_DHS, eigs_DST)))
models_df = dict(zip(_labels, (df_eff_gauss, df_eff_RHS, df_eff_DHS, df_eff_DST)))

print("df_eff summaries:")
for name, df in models_df.items():
    print(name, summarize_df(df))

print("\nEigenvalue summaries:")
for name, eigs in models_eigs.items():
    print(name, summarize_eigs(eigs, small=0.1, large=0.9))

# Histograms comparing how strongly each prior shrinks W
plot_hist_eigs(models_eigs, title="Histogram of shrinkage eigenvalues r")
plot_hist_log1p_rho(models_eigs, title="Histogram of log(1+rho)")


In [None]:

import matplotlib.pyplot as plt
from matplotlib.colors import SymLogNorm, Normalize
import numpy as np

def add_block_grid(ax, H, p, color="w", lw=0.5):
    """
    Draw horizontal and vertical lines separating H blocks of size p.

    Lines sit at k - 0.5 for k = p, 2p, ..., (H-1)p, i.e. on the pixel
    boundaries of an imshow'd (H*p, H*p) matrix.
    """
    for h in range(1, H):
        k = h * p
        ax.axhline(k - 0.5, color=color, lw=lw)
        ax.axvline(k - 0.5, color=color, lw=lw)

def visualize_models(
    matrices, names, H=16, p=10, use_abs=False, cmap="magma",
    q_low=0.05, q_high=0.95
):
    """
    Show heatmaps of several matrices with one shared, robust colour scale.

    The colour limits are pooled over ALL matrices:
      vmin = q_low-quantile of every finite entry,
      vmax = q_high-quantile of every finite entry,
    and values outside [vmin, vmax] are clipped to the ends. Panels are laid
    out on a near-square grid (2 x 2 for four matrices); grid cells without a
    matrix are hidden. add_block_grid draws the boundaries of the H blocks of
    size p on each panel.

    Args:
      matrices      : sequence of 2-D arrays.
      names         : panel titles, paired with `matrices` by position (zip, so
                      the shorter of the two limits the number of panels).
      H, p          : number of blocks and block size for the grid lines.
      use_abs       : plot |M| instead of M.
      cmap          : matplotlib colormap name.
      q_low, q_high : quantiles defining the shared colour limits.
    """
    # Optional absolute value
    mats = [np.abs(M) if use_abs else M for M in matrices]

    # Pool all finite values for the shared colour limits
    all_vals = np.concatenate([M[np.isfinite(M)].ravel() for M in mats]) if mats else np.array([])

    if all_vals.size == 0:
        vmin, vmax = -1.0, 1.0
    else:
        vmin = float(np.quantile(all_vals, q_low))
        vmax = float(np.quantile(all_vals, q_high))
        if not np.isfinite(vmin) or not np.isfinite(vmax) or vmin == vmax:
            # Fallback when the quantile window is degenerate: unit window around the mean
            m = float(np.nanmean(all_vals)) if np.isfinite(np.nanmean(all_vals)) else 0.0
            vmin, vmax = m - 1.0, m + 1.0

    norm = Normalize(vmin=vmin, vmax=vmax)

    # Near-square grid sized to the number of matrices (2 x 2 at 10 x 10 inches for four).
    n_panels = max(len(mats), 1)
    ncols = int(np.ceil(np.sqrt(n_panels)))
    nrows = int(np.ceil(n_panels / ncols))
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(5 * ncols, 5 * nrows), dpi=150,
        constrained_layout=True, squeeze=False,
    )
    axes = axes.ravel()

    im = None
    used_axes = []
    for ax, M, title in zip(axes, mats, names):
        Mplot = np.clip(M, vmin, vmax)  # clip values outside the shared window
        im = ax.imshow(Mplot, aspect='equal', interpolation='nearest', cmap=cmap, norm=norm)
        add_block_grid(ax, H, p)
        ax.set_title(title)
        ax.set_xlabel("Columns")
        ax.set_ylabel("Rows")
        used_axes.append(ax)

    # Hide grid cells that received no matrix
    for ax in axes[len(used_axes):]:
        ax.set_axis_off()

    # Shared colorbar
    if im is not None:
        fig.colorbar(im, ax=used_axes, label=("|Value|" if use_abs else "Value"))

    # Report the colour limits that were actually used
    print(f"Felles kvantiler: vmin (q={q_low}) = {vmin:.6g}, vmax (q={q_high}) = {vmax:.6g}")

    plt.show()

# Heatmaps of the draw-averaged shrinkage matrix for each model.
# (For a single draw instead of the mean, index e.g. R_gauss[1].)
matrices = [R.mean(axis=0) for R in (R_gauss, R_RHS, R_DHS, R_DST)]

names = [
    "Shrinkage (Gauss)",
    "Shrinkage (RHS)",
    "Shrinkage (DHS)",
    "Shrinkage (DST)",
]

visualize_models(matrices, names, H=16, p=10, use_abs=False)
