In [None]:
import sys, os; sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__) if '__file__' in globals() else os.getcwd(), '..')))
#import os; os.chdir(os.path.dirname(os.getcwd()))
from utils.model_loader import get_model_fits
import numpy as np
import pandas as pd
import re
from sklearn.metrics import mean_squared_error
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Directory (relative to the working directory) that holds the fitted models.
results_dir_experimental = "results/priors/single_layer/experiment"

# Display name of each prior -> config identifier handed to get_model_fits.
names_and_configs = {"Gaussian": "gauss_plain", "Regularized Horseshoe": "reg_hs", 
                     "Dirichlet Horseshoe": "dir_hs", "Dirichlet Student T": "stud_t_df3"}

# Loaded fits keyed by display name. Later cells read each entry as
# fits[name][name]['posterior'] and call .stan_variable(...) on it.
fits = {}


# Load one fit per prior and echo which name/config pair was just loaded.
# NOTE(review): include_prior=False presumably skips loading prior-only fits —
# confirm against utils.model_loader.get_model_fits.
for key, value in names_and_configs.items():
    experimental_fit = get_model_fits(
        config=value,
        results_dir=results_dir_experimental,
        models=key,
        include_prior=False,
    )
    print(key)
    print(value)
    fits[key] = experimental_fit


In [3]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Hidden node whose incoming weights are inspected. The value is used as a
# 0-based index into the last axis of W_1, so 1 selects the second hidden node.
node_idx = 1

def extract_W_for_node(cmdstan_mcmc, var_name="W_1", node=0):
    """Slice the draws of one weight matrix down to a single hidden node.

    Parameters
    ----------
    cmdstan_mcmc : object exposing ``stan_variable(name)``
        Fit object from which the draws of ``var_name`` are read.
    var_name : str
        Name of the Stan variable holding the weights.
    node : int
        0-based index along the last axis of the draws array.

    Returns
    -------
    The 2-D slice ``draws[:, :, node]``. The draws are assumed to be laid out
    as (n_draws, P, H), which makes the result (n_draws, P) — TODO confirm
    against the Stan model.
    """
    return cmdstan_mcmc.stan_variable(var_name)[:, :, node]

# Weight draws feeding into hidden node `node_idx`, one array per prior.
# Each fit is addressed as fits[name][name]['posterior'].
W_gauss, W_reg_hs, W_dir_hs, W_dir_st = (
    extract_W_for_node(fits[prior][prior]['posterior'], node=node_idx)
    for prior in (
        "Gaussian",
        "Regularized Horseshoe",
        "Dirichlet Horseshoe",
        "Dirichlet Student T",
    )
)

# Short plot label -> weight array; this mapping drives the comparisons below.
models = dict(zip(
    ("Gaussian", "Reg. Horseshoe", "Dirichlet HS", "Dirichlet Student-t"),
    (W_gauss, W_reg_hs, W_dir_hs, W_dir_st),
))


In [None]:
def rank_profile(W):
    """Summarize the magnitude of the k-th largest |weight| across draws.

    Within every row of ``W`` the absolute values are ordered from largest to
    smallest; the summaries are then taken over rows for each rank position.

    Returns
    -------
    (mean, q05, q95) : three arrays of length ``W.shape[1]`` holding the mean
    and the 5% / 95% quantiles of the rank-ordered |w|.
    """
    # Ascending sort per row, then flip the column axis to get descending order.
    descending = np.flip(np.sort(np.abs(W), axis=1), axis=1)
    lower, upper = (np.quantile(descending, q, axis=0) for q in (0.05, 0.95))
    return descending.mean(axis=0), lower, upper

# One line per prior: mean |w| at each rank position (rank 1 = largest |w|
# within a draw). The q05/q95 bands returned by rank_profile are unpacked
# here but not drawn.
plt.figure(figsize=(7,4))
for name, W in models.items():
    mean_rank, q05, q95 = rank_profile(W)
    x = np.arange(1, W.shape[1]+1)  # 1-based rank positions
    plt.plot(x, mean_rank, marker='o', label=name)
plt.xlabel("Rank (1 = largest |w| in draw)")
plt.ylabel("Average |w|")
plt.title(f"Avg. sorted |weights| into node {node_idx}")
plt.legend(); plt.tight_layout(); plt.show()


In [None]:
def topk_curve(W):
    """Expected share of total squared weight held by the k largest weights.

    Each row of ``W`` is converted to shares ``w_i**2 / sum_k w_k**2``, the
    shares are ordered from largest to smallest within the row, averaged over
    rows, and accumulated. Entry ``k-1`` of the result is therefore the mean
    over draws of the top-k share total (mean and cumulative sum commute).

    Returns
    -------
    1-D array of length ``W.shape[1]``.
    """
    energy = np.square(W)
    shares = energy / energy.sum(axis=1, keepdims=True)  # per-draw normalization
    descending = np.flip(np.sort(shares, axis=1), axis=1)
    return np.cumsum(descending.mean(axis=0))

# One cumulative top-k coverage curve per prior, with horizontal guides at
# 90% and 95% coverage.
plt.figure(figsize=(7,4))
for name, W in models.items():
    c = topk_curve(W)
    plt.plot(np.arange(1, len(c)+1), c, marker='.', label = name)  # k runs from 1
plt.axhline(0.9, ls='--', lw=1, label='90%')
plt.axhline(0.95, ls='--', lw=1, label='95%')
plt.xlabel("k (top-k weights)"); plt.ylabel("E[sum of top-k shares]")
plt.title(f"Top-k coverage of w² shares – node {node_idx}")
plt.legend(); plt.tight_layout(); plt.show()


In [None]:
def winner_freq(W):
    """Fraction of draws in which each input carries the largest |weight|.

    Returns
    -------
    1-D array of length ``W.shape[1]``; entry ``p`` is the share of rows whose
    largest absolute value sits in column ``p`` (first column wins on ties,
    following ``argmax``).
    """
    n_inputs = W.shape[1]
    top_input = np.abs(W).argmax(axis=1)
    # minlength keeps inputs that never win in the table with a zero count.
    tally = np.bincount(top_input, minlength=n_inputs)
    return tally / tally.sum()

# Winner frequencies side by side: one column per prior, one row per input
# (rows labelled input_0, input_1, ... in column order of W).
freq_df = pd.DataFrame({name: winner_freq(W) for name, W in models.items()})
freq_df.index = [f"input_{i}" for i in range(freq_df.shape[0])]
print(freq_df.round(3))


In [None]:
def gini(v):
    """Gini coefficient of a 1-D vector via the sorted-rank formula.

    Computes ``sum_i (2*i - n - 1) * v_(i) / (n * sum(v))`` with ``v_(i)`` the
    i-th smallest value (i = 1..n). A constant vector gives 0; a one-hot
    vector of length n gives (n - 1) / n.
    """
    ordered = np.sort(v)
    count = ordered.size
    rank_weights = 2 * np.arange(1, count + 1) - count - 1
    return (rank_weights * ordered).sum() / (count * ordered.sum())

def gini_over_draws(W):
    """Gini coefficient of the squared-weight shares, one value per draw.

    Every row of ``W`` is squared and normalized to sum to one, then passed to
    ``gini``. Returns a 1-D array with one entry per row of ``W``.
    """
    energy = np.square(W)
    shares = energy / energy.sum(axis=1, keepdims=True)
    return np.apply_along_axis(gini, axis=1, arr=shares)

# Report the mean per-draw Gini coefficient for each prior.
# Per gini's formula a one-hot row of length n scores (n - 1) / n, so the
# "1=one-hot" legend in the printed line is the large-n limit.
for name, W in models.items():
    g = gini_over_draws(W)
    print(f"{name}: mean Gini = {g.mean():.3f}   (0=uniform, 1=one-hot)")


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Small positive constant that guards divisions and logarithms against zero.
EPS = 1e-12


def shares_from_W(W):
    """Squared weights normalized within each row (one row per draw).

    Returns an array shaped like ``W`` whose rows hold
    ``w_i**2 / (sum_k w_k**2 + EPS)``; the EPS term keeps an all-zero row
    from dividing by zero.
    """
    energy = np.square(W)
    row_total = energy.sum(axis=1, keepdims=True)
    return energy / (row_total + EPS)

# --- Shannon entropy and friends ---
def shannon_entropy_over_draws(W):
    """Shannon entropy (in nats) of the squared-weight shares of each draw.

    Ranges from about 0 for a one-hot row to about ln P for a uniform row,
    where P is the number of columns; EPS inside the logarithm keeps zero
    shares finite.

    Returns a 1-D array with one entropy value per row of ``W``.
    """
    shares = shares_from_W(W)
    log_shares = np.log(shares + EPS)
    return -np.sum(shares * log_shares, axis=1)

def norm_entropy_over_draws(W):
    """Entropy-based sparsity score per draw: ``1 - H / ln P``.

    Roughly 0 when the shares are uniform and roughly 1 when a single input
    carries all of the squared weight.
    """
    entropy = shannon_entropy_over_draws(W)
    max_entropy = np.log(W.shape[1]) + EPS  # EPS guards the P == 1 case
    return 1.0 - entropy / max_entropy

def perplexity_over_draws(W):
    """Perplexity ``exp(H)`` of the squared-weight shares, one value per draw.

    Reads as an effective number of active inputs: about 1 for a one-hot row,
    about P for a uniform row.
    """
    return np.exp(shannon_entropy_over_draws(W))


# --- Convenience: summarize per model ---
def summarize_entropy_kl(models):
    """Tabulate entropy-based sparsity summaries for every model.

    Parameters
    ----------
    models : mapping of model name -> 2-D weight array (one row per draw).

    Returns
    -------
    pandas.DataFrame indexed by ``model`` with the mean and the 5% / 95%
    quantiles over draws of the entropy in nats, the normalized score
    ``1 - H / ln P`` and the perplexity ``exp(H)``. Despite the function
    name, no KL-divergence column is produced.
    """
    def _mean_and_band(values):
        # Mean plus the 5% and 95% quantiles of a per-draw quantity.
        return np.mean(values), np.quantile(values, 0.05), np.quantile(values, 0.95)

    records = []
    for name, W in models.items():
        entropy = shannon_entropy_over_draws(W)
        h_mean, h_lo, h_hi = _mean_and_band(entropy)
        hn_mean, hn_lo, hn_hi = _mean_and_band(1 - entropy / (np.log(W.shape[1]) + EPS))
        ppx_mean, ppx_lo, ppx_hi = _mean_and_band(np.exp(entropy))

        records.append({
            "model": name,
            "H (nats) mean": h_mean,
            "H p05": h_lo,
            "H p95": h_hi,
            "H_norm mean": hn_mean,
            "H_norm p05": hn_lo,
            "H_norm p95": hn_hi,
            "Perplexity mean": ppx_mean,
            "Perplexity p05": ppx_lo,
            "Perplexity p95": ppx_hi,
        })

    return pd.DataFrame(records).set_index("model")

# Compact summary table: mean normalized entropy score and mean perplexity
# for each prior.
summary_df = summarize_entropy_kl(models)
print(summary_df[[
    "H_norm mean", "Perplexity mean"
]].round(3))

# Bar chart of the mean normalized-entropy sparsity score per prior.
# Only one series is drawn, so each bar is centred on its tick; the earlier
# x - bar_w/2 offset belonged to a two-series layout and pushed the bars off
# their labels. The title names only the entropy measure, since no KL
# quantity is computed here.
plt.figure(figsize=(6.5, 3.8))
x = np.arange(len(models))
bar_w = 0.35
Hn_means = [norm_entropy_over_draws(W).mean() for W in models.values()]
plt.bar(x, Hn_means, width=bar_w, label="1 - H/ln P")
plt.xticks(x, list(models.keys()), rotation=15)
plt.ylim(0, 1)
plt.ylabel("Normalized (0=uniform, 1=one-hot)")
plt.title(f"Entropy sparsity — node {node_idx}")
plt.legend()
plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import rankdata

# generic: rank→uniform columnwise (works for any matrix A)
def empirical_copula_cols(A):
    """Map every column of a 2-D array to pseudo-observations in (0, 1).

    Each column is replaced by its average ranks (ties share the mean rank)
    divided by ``n + 1``, where ``n`` is the number of rows.

    Returns a float array with the same shape as ``A``.
    """
    n_rows, _ = A.shape
    pseudo_obs = np.zeros(A.shape, dtype=float)
    for col, values in enumerate(A.T):
        pseudo_obs[:, col] = rankdata(values, method="average") / (n_rows + 1.0)
    return pseudo_obs

def tail_dependence_curve_from_matrix(A, i=0, j=1, u_grid=None):
    """Empirical upper-tail dependence curve between two columns of ``A``.

    The columns are rank-transformed with ``empirical_copula_cols``; for each
    threshold ``u`` the curve holds the fraction of rows with ``U_j > u`` among
    the rows with ``U_i > u``, or NaN when no row exceeds ``u`` in column ``i``.

    Parameters
    ----------
    A : 2-D array, one row per observation.
    i, j : column indices of the conditioning and the conditioned variable.
    u_grid : thresholds; defaults to 25 points from 0.50 to 0.99.

    Returns
    -------
    (u_grid, lam) with ``lam`` a float array aligned with ``u_grid``.
    """
    if u_grid is None:
        u_grid = np.linspace(0.50, 0.99, 25)

    ranks = empirical_copula_cols(A)
    ui, uj = ranks[:, i], ranks[:, j]

    lam = np.full(len(u_grid), np.nan)
    for k, u in enumerate(u_grid):
        above = ui > u
        if np.any(above):
            lam[k] = np.mean(uj[above] > u)
    return u_grid, lam

def plot_tail_dependence(models, i=0, j=1, u_grid=None):
    """Plot the upper-tail dependence of |w| between two inputs, per model.

    One panel per model on a fixed 2x2 grid (so at most four models are
    shown), each with the empirical curve and the ``1 - u`` line expected
    under independence. All panels share the same axis limits.

    Parameters
    ----------
    models : mapping of model name -> 2-D weight array (one row per draw).
    i, j : column indices of the two inputs being compared.
    u_grid : array of thresholds; defaults to 25 points from 0.50 to 0.99.
    """
    if u_grid is None:
        u_grid = np.linspace(0.50, 0.99, 25)

    # One (u, lambda) pair per model, computed on absolute weights.
    curves = {
        name: tail_dependence_curve_from_matrix(np.abs(W), i=i, j=j, u_grid=u_grid)
        for name, W in models.items()
    }
    baseline = 1.0 - u_grid

    # Common y-limit: the largest value over all curves and the baseline.
    ymax = 0.0
    for _, lam in curves.values():
        ymax = max(ymax, np.nanmax(lam))
    ymax = max(ymax, baseline.max())
    y_top = min(1.0, ymax * 1.05)

    fig, axes = plt.subplots(2, 2, figsize=(10, 9), sharex=True, sharey=True)
    for ax, (name, (u, lam)) in zip(axes.ravel(), curves.items()):
        ax.plot(u, lam, marker='o', ms=3, lw=1, label=name)
        ax.plot(u_grid, baseline, linestyle='--', lw=1, label='independent baseline (1 - u)')
        ax.set_title(name)
        ax.set_ylim(0, y_top)
        ax.grid(True, linewidth=0.4, alpha=0.4)
        ax.set_xlabel("u (upper-tail threshold)")
        ax.set_ylabel(r"$\lambda_U(u) = P(U_j>u \mid U_i>u)$")
        ax.legend(loc='upper right', fontsize=8)

    fig.suptitle(f"Upper-tail dependence — |w|, inputs {i} vs {j}", y=0.98)
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()

# Example: upper-tail dependence of |w| between inputs 0 and 1 for every prior.
plot_tail_dependence(models, i=0, j=1)


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from numpy.random import default_rng
# assumes empirical_copula_cols() is already defined above
# assumes shares_from_W(W) is already defined in your session

def _gaussian_shares_baseline_lambdaU(P, i=0, j=1, u_grid=None, N=200_000, seed=0):
    """Simulated upper-tail dependence curve for Dirichlet(1/2) shares.

    Draws ``N`` vectors from a symmetric Dirichlet with concentration 0.5 in
    ``P`` dimensions (the distribution of normalized squares of i.i.d.
    zero-mean Gaussians), rank-transforms the columns and evaluates the
    fraction of draws with ``U_j > u`` among those with ``U_i > u``.

    Parameters
    ----------
    P : number of inputs (dimension of each share vector).
    i, j : column indices of the conditioning and the conditioned input.
    u_grid : thresholds; defaults to 25 points from 0.50 to 0.99.
    N : number of simulated share vectors.
    seed : seed for the NumPy random generator.

    Returns
    -------
    (u_grid, lam_base) with NaN wherever no draw exceeds the threshold in
    column ``i``.
    """
    if u_grid is None:
        u_grid = np.linspace(0.50, 0.99, 25)

    concentration = np.full(P, 0.5, dtype=float)
    simulated_shares = default_rng(seed).dirichlet(concentration, size=N)
    ranks = empirical_copula_cols(simulated_shares)
    ui, uj = ranks[:, i], ranks[:, j]

    lam_base = np.full(len(u_grid), np.nan)
    for k, u in enumerate(u_grid):
        above = ui > u
        if np.any(above):
            lam_base[k] = np.mean(uj[above] > u)
    return u_grid, lam_base

def tail_dependence_on_shares_gauss_baseline(models, i=0, j=1, u_grid=None, N=200_000, seed=0):
    """Plot upper-tail dependence of squared-weight shares against a baseline.

    For every model the shares ``w**2 / sum w**2`` are rank-transformed and
    the curve ``P(U_j > u | U_i > u)`` is drawn next to the simulated
    Dirichlet(1/2) baseline. Panels sit on a fixed 2x2 grid (at most four
    models are shown) and share their axis limits.

    Parameters
    ----------
    models : mapping of model name -> 2-D weight array (one row per draw);
        the number of inputs is read from the first entry.
    i, j : column indices of the two inputs being compared.
    u_grid : thresholds; defaults to 25 points from 0.50 to 0.99.
    N, seed : size and seed of the baseline simulation.
    """
    if u_grid is None:
        u_grid = np.linspace(0.50, 0.99, 25)

    n_inputs = next(iter(models.values())).shape[1]
    u_base, lam_base = _gaussian_shares_baseline_lambdaU(
        n_inputs, i=i, j=j, u_grid=u_grid, N=N, seed=seed
    )

    def _share_curve(W):
        # Conditional exceedance curve of one model's rank-transformed shares.
        ranks = empirical_copula_cols(shares_from_W(W))
        ui, uj = ranks[:, i], ranks[:, j]
        return np.array([np.mean(uj[ui > u] > u) if np.any(ui > u) else np.nan for u in u_grid])

    curves = {name: _share_curve(W) for name, W in models.items()}

    # Common y-limit over the baseline and every model curve.
    ymax = np.nanmax(lam_base)
    for lam in curves.values():
        ymax = max(ymax, np.nanmax(lam))
    y_top = min(1.0, ymax * 1.05)

    fig, axes = plt.subplots(2, 2, figsize=(10, 9), sharex=True, sharey=True)
    for ax, (name, lam) in zip(axes.ravel(), curves.items()):
        ax.plot(u_grid, lam, marker='o', ms=3, lw=1, label=name)
        ax.plot(u_base, lam_base, ls='--', lw=1.2, label='Gaussian baseline: Dir(½)')
        ax.set_title(name)
        ax.set_ylim(0, y_top)
        ax.grid(True, linewidth=0.4, alpha=0.4)
        ax.set_xlabel("u (upper-tail threshold)")
        ax.set_ylabel(r"$\lambda_U(u) = P(U_j>u \mid U_i>u)$")
        ax.legend(fontsize=8, loc='upper right')

    fig.suptitle(r"Upper-tail dependence of $\frac{w^2}{\sum_k w_k^2}$ — (baseline = Gaussian ⇒ Dir(½))", y=0.98)
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()

# Example: tail dependence of squared-weight shares between inputs 0 and 1,
# against a baseline simulated from 200,000 draws with a fixed seed.
tail_dependence_on_shares_gauss_baseline(models, i=0, j=1, N=200_000, seed=42)


## Exploratory checks (scratch work)

In [62]:
import numpy as np
import matplotlib.pyplot as plt

def winner_mask_col0(W, drop_ties=True):
    """Boolean mask of the rows whose maximum sits in column 0.

    Parameters
    ----------
    W : 2-D array, one row per draw.
    drop_ties : when True, rows whose maximum value occurs more than once are
        excluded even if column 0 is one of the maxima.

    NOTE(review): the comparison uses signed values, not |w| as ``winner_freq``
    does — confirm that this is intended.
    """
    first_col_wins = W.argmax(axis=1) == 0
    if not drop_ties:
        return first_col_wins
    peak = W.max(axis=1, keepdims=True)
    n_at_peak = (W == peak).sum(axis=1)
    return first_col_wins & (n_at_peak == 1)

def transform_weights(W, mask=None, transform="none"):
    """Rescale a weight matrix with one of several column-wise transforms.

    Parameters
    ----------
    W : 2-D array, one row per draw; it is copied, never modified.
    mask : boolean row mask, required for the relative transforms.
    transform : one of
        ``"none"``          copy of ``W``;
        ``"zscore"``        columns centred on their mean and divided by their std;
        ``"mad"``           columns centred on their median and divided by
                            1.4826 times their median absolute deviation;
        ``"relative"``      masked rows divided by ``|column 0|`` of that row;
        ``"abs_relative"``  as above, but on absolute values.

    Returns
    -------
    A single array for ``none`` / ``zscore`` / ``mad``. For the relative
    transforms a tuple ``(transformed_masked_rows, untouched_unmasked_rows)``.

    Raises
    ------
    ValueError
        If a relative transform is requested without ``mask`` or the
        transform name is unknown.
    """
    X = W.copy()

    if transform == "none":
        return X

    if transform == "zscore":
        center = X.mean(axis=0, keepdims=True)
        scale = X.std(axis=0, keepdims=True) + 1e-12
        return (X - center) / scale

    if transform == "mad":
        center = np.median(X, axis=0, keepdims=True)
        spread = np.median(np.abs(X - center), axis=0, keepdims=True) + 1e-12
        return (X - center) / (1.4826 * spread)

    if transform not in ("relative", "abs_relative"):
        raise ValueError("unknown transform")
    if mask is None:
        raise ValueError("mask required for relative transforms")

    masked_rows = X[mask].copy()
    scale = np.abs(masked_rows[:, [0]]) + 1e-12
    numerator = masked_rows if transform == "relative" else np.abs(masked_rows)
    # Rows outside the mask keep their original scale (available for a
    # combined histogram if wanted).
    return numerator / scale, X[~mask].copy()

def plot_conditioned_hist_grid(models, bins=100, qlim=(0.01, 0.99)):
    """Plot conditioned histograms with a quantile clip on the x-axis.

    For each model the rows are split by ``winner_mask_col0`` into "winning"
    rows (column 0 holds the unique row maximum) and the rest. A 3x3 grid
    then shows, for columns 1..9, the density histograms of both groups.

    Parameters
    ----------
    models : mapping of model name -> 2-D weight array with at least 10 columns.
    bins : number of histogram bins.
    qlim : (low, high) quantiles of the pooled column values that bound the
        visible x-range, e.g. (0.01, 0.99).
    """
    for name, W in models.items():
        is_winner = winner_mask_col0(W, drop_ties=True)
        groups = {"winning": W[is_winner], "non-winning": W[~is_winner]}

        fig, axes = plt.subplots(3, 3, figsize=(12, 9))
        panels = axes.flatten()
        for col, ax in enumerate(panels, start=1):
            # x-limits come from the pooled values of both groups.
            pooled = np.concatenate([rows[:, col] for rows in groups.values()])
            lo, hi = np.quantile(pooled, qlim)
            for label, rows in groups.items():
                ax.hist(rows[:, col], bins=bins, density=True, alpha=0.5,
                        label=label, edgecolor='none')
            ax.set_xlim(lo, hi)
            ax.set_title(f"w{col}")

        fig.suptitle(f"{name} — conditioned distributions (clipped {qlim})", fontsize=14)
        panels[0].legend()
        plt.tight_layout()
        plt.show()


#plot_conditioned_hist_grid(models)

In [19]:
def plot_scatter_grid(models, s=6, alpha=0.2):
    """Scatter column 0 against columns 1..9 on the rows where column 0 wins.

    "Winning" rows are those selected by ``winner_mask_col0`` with ties
    dropped. Models without any winning row are skipped. Each model gets its
    own 3x3 figure with zero reference lines on both axes.

    Parameters
    ----------
    models : mapping of model name -> 2-D weight array with at least 10 columns.
    s, alpha : marker size and opacity passed to ``scatter``.
    """
    for name, W in models.items():
        winners = W[winner_mask_col0(W, drop_ties=True)]
        if len(winners) == 0:
            continue

        fig, axes = plt.subplots(3, 3, figsize=(12, 9))
        for col, ax in enumerate(axes.flatten(), start=1):
            ax.scatter(winners[:, 0], winners[:, col], s=s, alpha=alpha)
            ax.axhline(0, linewidth=0.8, color="k")
            ax.axvline(0, linewidth=0.8, color="k")
            ax.set_title(f"w0 vs w{col}")

        fig.suptitle(f"{name} — winning rows", fontsize=14)
        plt.tight_layout()
        plt.show()

# Example:
#plot_scatter_grid(models)


In [59]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Lower bound applied to standard deviations so later logs stay finite.
EPS = 1e-12


def stddev_from_draws(phi_draws, lambda_draws):
    """Per-weight standard deviation ``sqrt(max(lambda, 0)) * phi`` per draw.

    Parameters
    ----------
    phi_draws : scalar or array that broadcasts against ``lambda_draws``.
    lambda_draws : 3-D array, assumed to be laid out as (draws, J, I) —
        TODO confirm against the Stan model.

    Returns
    -------
    Array with axes 1 and 2 swapped relative to the input, i.e. (draws, I, J)
    under the assumed layout, floored at EPS.
    """
    nonneg_lambda = np.maximum(lambda_draws, 0.0)  # negative values are clamped to 0
    scale = np.sqrt(nonneg_lambda) * phi_draws
    floored = np.maximum(scale, EPS)
    return np.swapaxes(floored, 1, 2)

# Gaussian reference: phi = 1 and lambda = 1 everywhere, i.e. a constant unit scale.
# NOTE(review): the (8000, 10, 16) shape is hard-coded — presumably it mirrors
# the (draws, J, I) layout of the fitted models; confirm against the posterior draws.
phi_GAUSS = 1
lam_GAUSS = np.ones((8000, 10, 16))
std_GAUSS = stddev_from_draws(phi_GAUSS, lam_GAUSS)  # (draws, I, J)

# Regularized horseshoe: only lambda_tilde is read; phi is fixed to 1.
phi_RHS = 1
lam_RHS = fits['Regularized Horseshoe']['Regularized Horseshoe']['posterior'].stan_variable("lambda_tilde")
std_RHS = stddev_from_draws(phi_RHS, lam_RHS)  # (draws, I, J)

# Dirichlet horseshoe: both phi and lambda come from the posterior draws.
phi_DHS = fits['Dirichlet Horseshoe']['Dirichlet Horseshoe']['posterior'].stan_variable("phi_tilde_data")
lam_DHS = fits['Dirichlet Horseshoe']['Dirichlet Horseshoe']['posterior'].stan_variable("lambda_tilde_data")
std_DHS = stddev_from_draws(phi_DHS, lam_DHS)  # (draws, I, J)

# Dirichlet Student-t: phi from phi_tilde_data, lambda from lambda_tilde.
# NOTE(review): this reads "lambda_tilde" where the Dirichlet horseshoe reads
# "lambda_tilde_data" — confirm the variable name against the Stan model.
phi_DST = fits['Dirichlet Student T']['Dirichlet Student T']['posterior'].stan_variable("phi_tilde_data")
lam_DST = fits['Dirichlet Student T']['Dirichlet Student T']['posterior'].stan_variable("lambda_tilde")
std_DST = stddev_from_draws(phi_DST, lam_DST)  # (draws, I, J)


# Plot label -> stddev array; consumed by plot_std_ecdf.
models_std = {
    "Dir-HS": std_DHS,
    "Gaussian": std_GAUSS,
    "RHS": std_RHS,
    "Dir-ST": std_DST,
}


In [None]:
def ecdf(x):
    """Empirical CDF of a 1-D sample.

    Returns
    -------
    (sorted_values, heights) where ``heights[k] = (k + 1) / n`` is the fraction
    of observations less than or equal to ``sorted_values[k]``.
    """
    ordered = np.sort(x)
    count = ordered.size
    heights = (np.arange(count) + 1) / count
    return ordered, heights

def plot_std_ecdf(models_std, log=True, qclip=(0.001, 0.999)):
    """Overlay the ECDF of the pooled standard deviations of every model.

    Parameters
    ----------
    models_std : mapping of model name -> stddev array; all entries of each
        array are pooled into one sample.
    log : plot log(stddev) when True, raw stddev otherwise.
    qclip : (low, high) quantiles; values outside this range are dropped
        before the ECDF is computed.
    """
    plt.figure(figsize=(7,5))
    for name, std in models_std.items():
        # Flatten over every axis and floor at EPS so the log is finite.
        pooled = np.maximum(std.reshape(-1), EPS)
        if log:
            pooled = np.log(pooled)
        lo, hi = np.quantile(pooled, qclip)
        inside = (pooled >= lo) & (pooled <= hi)
        x, y = ecdf(pooled[inside])
        plt.plot(x, y, label=name, linewidth=1.8)

    plt.xlabel("log stddev" if log else "stddev")
    plt.ylabel("ECDF")
    plt.title("Stddev distribution (pooled across weights)")
    plt.legend()
    plt.tight_layout()
    plt.show()

# Example: ECDF of log stddev per model, keeping values between the 0.1% and
# 99.9% quantiles.
plot_std_ecdf(models_std, log=True, qclip=(0.001, 0.999))


In [None]:
import scipy.stats as stats
# Gaussian prior: Chatterjee's xi between the draws of inputs 0 and 1.
# The bare last expression displays the statistic as the cell output.
res = stats.chatterjeexi(W_gauss[:, 0], W_gauss[:, 1])
res.statistic

In [None]:
# Regularized horseshoe prior: Chatterjee's xi between the draws of inputs 0 and 1.
# Relies on `stats` (scipy.stats) imported in an earlier cell.
res = stats.chatterjeexi(W_reg_hs[:, 0], W_reg_hs[:, 1])
res.statistic

In [None]:
# Dirichlet horseshoe prior: Chatterjee's xi between the draws of inputs 0 and 1.
# Relies on `stats` (scipy.stats) imported in an earlier cell.
res = stats.chatterjeexi(W_dir_hs[:, 0], W_dir_hs[:, 1])
res.statistic

In [None]:
# Dirichlet Student-t prior: Chatterjee's xi between the draws of inputs 0 and 1.
# Relies on `stats` (scipy.stats) imported in an earlier cell.
res = stats.chatterjeexi(W_dir_st[:, 0], W_dir_st[:, 1])
res.statistic

In [None]:
from covercorr import coverage_correlation

# Gaussian prior: coverage correlation between the draws of inputs 0 and 1.
# The call is unpacked into a statistic and a p-value, both printed below.
kappa, pval = coverage_correlation(W_gauss[:, 0], W_gauss[:, 1])
print(f"Coverage correlation: {kappa}, p-value: {pval}")

In [None]:
from covercorr import coverage_correlation

# Regularized horseshoe prior: coverage correlation between the draws of inputs 0 and 1.
kappa, pval = coverage_correlation(W_reg_hs[:, 0], W_reg_hs[:, 1])
print(f"Coverage correlation: {kappa}, p-value: {pval}")

In [None]:
from covercorr import coverage_correlation

# Dirichlet horseshoe prior: coverage correlation between the draws of inputs 0 and 1.
kappa, pval = coverage_correlation(W_dir_hs[:, 0], W_dir_hs[:, 1])
print(f"Coverage correlation: {kappa}, p-value: {pval}")

In [None]:
from covercorr import coverage_correlation

# Dirichlet Student-t prior: coverage correlation between the draws of inputs
# 0 and 1. This is the same input pair used for the other three priors, so the
# four values are directly comparable; the cell previously compared inputs 1 and 2.
kappa, pval = coverage_correlation(W_dir_st[:, 0], W_dir_st[:, 1])
print(f"Coverage correlation: {kappa}, p-value: {pval}")