In [None]:
import sys, os; sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__) if '__file__' in globals() else os.getcwd(), '..')))
from utils.model_loader import get_model_fits
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Locations of the correlated-Friedman datasets and of the fitted
# single-layer ReLU / tanh posteriors (no-lambda variants).
data_dir = "datasets/friedman_correlated"
results_dir_relu_correlated = "results/regression/single_layer/relu/friedman_correlated/no_lambda"
results_dir_tanh_correlated = "results/regression/single_layer/tanh/friedman_correlated/no_lambda"

model_names_relu = ["Dirichlet tau", "Beta tau", "Dirichlet", "Beta"]
model_names_tanh = ["Dirichlet tau tanh", "Beta tau tanh", "Dirichlet tanh", "Beta tanh"]

# dataset file stem (e.g. "Friedman_N100_p10_sigma1.00_seed1") -> result of get_model_fits
relu_fits_correlated = {}
tanh_fits_correlated = {}
files = sorted(f for f in os.listdir(data_dir) if f.endswith(".npz"))
for fname in files:
    # Strip only the trailing extension (str.replace would also remove any
    # ".npz" occurring elsewhere in the name).
    base_config_name = fname[: -len(".npz")]
    # Files sit directly in data_dir, so the config passed to get_model_fits
    # is just the file stem.
    full_config_path = base_config_name
    relu_fit_correlated = get_model_fits(
        config=full_config_path,
        results_dir=results_dir_relu_correlated,
        models=model_names_relu,
        include_prior=False,
    )

    tanh_fit_correlated = get_model_fits(
        config=full_config_path,
        results_dir=results_dir_tanh_correlated,
        models=model_names_tanh,
        include_prior=False,
    )

    relu_fits_correlated[base_config_name] = relu_fit_correlated
    tanh_fits_correlated[base_config_name] = tanh_fit_correlated
    


In [26]:
import re
import numpy as np
import pandas as pd
from properscoring import crps_ensemble
from scores.probability import crps_for_ensemble

def compute_rmse_from_fits(all_fits, model_names=None, folder="friedman"):
    """
    Compute test-set RMSE for every fitted model in `all_fits`.

    Iterates over all dataset keys in `all_fits` (e.g., relu_fits or tanh_fits);
    keys that `extract_friedman_metadata` cannot parse are skipped. For each
    model in `model_names` (or every model present for that key when
    `model_names` is falsy), computes:
      - RMSE for each posterior draw of `output_test`
      - RMSE of the posterior mean predictor

    Test targets are read from
    ``datasets/{folder}/Friedman_N{N}_p10_sigma{sigma:.2f}_seed{seed}.npz``,
    falling back to the ``datasets/{folder}/many/`` subfolder.

    Returns:
        df_rmse: long DF with one row per posterior draw.
        df_posterior_rmse: one row per model/dataset with posterior-mean RMSE.

    Raises:
        FileNotFoundError: if the dataset file exists in neither location.
        ValueError: if `output_test` has an unexpected shape or a different
            number of test points than `y_test`.
    """
    rmse_rows = []
    post_mean_rows = []

    for dataset_key, model_dict in all_fits.items():
        N, sigma, seed = extract_friedman_metadata(dataset_key)
        if N is None:
            # Skip non-Friedman entries if any
            continue

        y_test = _load_friedman_y_test(folder, N, sigma, seed)  # shape (N_test,)

        # Choose which models to evaluate
        models_to_eval = model_names or list(model_dict.keys())

        for model in models_to_eval:
            # Some entries may be missing
            entry = model_dict.get(model, None)
            if not entry or "posterior" not in entry:
                print(f"[SKIP] Missing posterior: {dataset_key} -> {model}")
                continue

            fit = entry["posterior"]

            # Expecting (S, N_test, 1) or (S, N_test)
            output_test = fit.stan_variable("output_test")
            if output_test.ndim == 3 and output_test.shape[-1] == 1:
                preds = output_test[..., 0]  # (S, N_test)
            elif output_test.ndim == 2:
                preds = output_test  # (S, N_test)
            else:
                raise ValueError(f"Unexpected output_test shape {output_test.shape} for {dataset_key} -> {model}")
            if preds.shape[1] != y_test.shape[0]:
                raise ValueError(
                    f"output_test has {preds.shape[1]} test points but y_test has "
                    f"{y_test.shape[0]} for {dataset_key} -> {model}"
                )

            # Per-sample RMSE
            rmse_per_sample = np.sqrt(np.mean((preds - y_test[None, :]) ** 2, axis=1))  # (S,)
            rmse_rows.extend(
                {
                    "dataset_key": dataset_key,
                    "model": model,
                    "N": N,
                    "sigma": sigma,
                    "seed": seed,
                    "sample_idx": s_idx,
                    "rmse": float(rmse),
                }
                for s_idx, rmse in enumerate(rmse_per_sample)
            )

            # Posterior-mean RMSE
            posterior_mean = preds.mean(axis=0)  # (N_test,)
            post_mean_rmse = float(np.sqrt(np.mean((posterior_mean - y_test) ** 2)))
            post_mean_rows.append({
                "dataset_key": dataset_key,
                "model": model,
                "N": N,
                "sigma": sigma,
                "seed": seed,
                "posterior_mean_rmse": post_mean_rmse
            })

    df_rmse = pd.DataFrame(rmse_rows)
    df_posterior_rmse = pd.DataFrame(post_mean_rows)
    return df_rmse, df_posterior_rmse


def _load_friedman_y_test(folder, N, sigma, seed):
    """Load `y_test` as a 1-D array from datasets/{folder}/ or datasets/{folder}/many/."""
    fname = f"Friedman_N{N}_p10_sigma{sigma:.2f}_seed{seed}.npz"
    candidates = (f"datasets/{folder}/{fname}", f"datasets/{folder}/many/{fname}")
    for path in candidates:
        try:
            # Context manager closes the file handle that np.load keeps open for .npz.
            with np.load(path) as data:
                # reshape(-1) instead of squeeze(): squeeze turns a single test
                # point into a 0-d array, which breaks y_test[None, :].
                return np.asarray(data["y_test"]).reshape(-1)
        except FileNotFoundError:
            continue
    raise FileNotFoundError(f"y_test file not found; tried: {', '.join(candidates)}")

_FRIEDMAN_KEY = re.compile(r"Friedman_N(\d+)_p\d+_sigma([\d.]+)_seed(\d+)")

def extract_friedman_metadata(key: str):
    """
    Pull (N, sigma, seed) out of a dataset key such as
    'Friedman_N100_p10_sigma1.00_seed1' -> (100, 1.0, 1).

    The pattern may appear anywhere inside `key` (prefixes/suffixes allowed).
    Returns (None, None, None) when no match is found.
    """
    match = _FRIEDMAN_KEY.search(key)
    if match is None:
        return None, None, None
    n_text, sigma_text, seed_text = match.groups()
    return int(n_text), float(sigma_text), int(seed_text)


# Per-draw and posterior-mean test RMSE for the correlated-Friedman fits,
# one call per activation family.
(df_rmse_relu_correlated,
 df_posterior_rmse_relu_correlated) = compute_rmse_from_fits(
    relu_fits_correlated, model_names=model_names_relu, folder="friedman_correlated"
)

(df_rmse_tanh_correlated,
 df_posterior_rmse_tanh_correlated) = compute_rmse_from_fits(
    tanh_fits_correlated, model_names=model_names_tanh, folder="friedman_correlated"
)


In [29]:
import pandas as pd

# Label the per-draw RMSE tables with activation / data setting, then stack them.
df3, df4 = (
    df_rmse_relu_correlated.assign(activation="ReLU", setting="Correlated"),
    df_rmse_tanh_correlated.assign(activation="Tanh", setting="Correlated"),
)
df_all = pd.concat((df3, df4), ignore_index=True)

# Same labelling for the posterior-mean RMSE tables.
df3_pm, df4_pm = (
    df_posterior_rmse_relu_correlated.assign(activation="ReLU", setting="Correlated"),
    df_posterior_rmse_tanh_correlated.assign(activation="Tanh", setting="Correlated"),
)
df_all_pm = pd.concat((df3_pm, df4_pm), ignore_index=True)

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# --- prepare data ---
df = df_all.copy()

# LaTeX legend labels for each prior family.
abbr = {
    "Dirichlet tau": r"Dir - $\tau$",
    "Beta tau": r"Beta - $\tau$",
    "Dirichlet": r"Dir",
    "Beta": r"Beta",
}

# unify model names across activations (strip " tanh")
df["model_clean"] = df["model"].str.replace(" tanh", "", regex=False)

# mean / std of per-draw RMSE per (setting, N, model, activation)
summary = (
    df.groupby(["setting", "N", "model_clean", "activation"], as_index=False)["rmse"]
      .agg(mean="mean", std="std")
)

# plotting order. Only settings that have results get a panel: df_all is built
# from "Correlated" rows only, so a hard-coded ["Original", "Correlated"]
# produced an empty left panel.
candidate_settings = ["Original", "Correlated"]
settings = [s for s in candidate_settings if (summary["setting"] == s).any()]
if not settings:
    raise ValueError(f"No RMSE results to plot for any of {candidate_settings}")
Ns = [50, 100, 200, 500]
models = ["Dirichlet tau", "Beta tau", "Dirichlet", "Beta"]

# visuals
markers = {"ReLU": "o", "Tanh": "^"}            # shapes
offsets = {"ReLU": -0.12, "Tanh": +0.12}        # side-by-side jitter on x
model_offsets = {
    "Dirichlet tau": -0.04,
    "Beta tau": -0.01,
    "Dirichlet": +0.01,
    "Beta": +0.04,
}
palette_list = plt.get_cmap("tab10").colors
palette = {m: palette_list[i] for i, m in enumerate(models)}

# map N to base x positions and add offsets for activation
xbase = {N: i for i, N in enumerate(Ns)}

# One 6-inch-wide panel per setting (12 x 7 when both settings are present, as before).
fig, axes = plt.subplots(1, len(settings), figsize=(6 * len(settings), 7), sharey=True, squeeze=False)
axes = axes[0]

for ax, setting in zip(axes, settings):
    sub = summary[summary["setting"] == setting]
    # plot each model+activation with errorbars, without lines
    for m in models:
        for act in ["ReLU", "Tanh"]:
            g = sub[(sub["model_clean"] == m) & (sub["activation"] == act)]
            if g.empty:
                continue
            xs = [xbase[n] + offsets[act] + model_offsets[m] for n in g["N"]]

            ax.errorbar(
                xs, g["mean"], yerr=g["std"],
                fmt=markers[act], markersize=10,
                linestyle="none", capsize=3,
                color=palette[m], markeredgecolor="black"
            )

    ax.set_title(f"{setting}", fontsize=15)
    ax.set_xticks(range(len(Ns)))
    ax.set_xticklabels(Ns, fontsize=15)
    ax.set_xlabel("N", fontsize=15)
    ax.set_ylabel("RMSE", fontsize=15)
    ax.tick_params(axis='y', labelsize=15)
    ax.grid()

# --- legends ---
# colour = prior family
model_handles = [
    Line2D(
        [0], [0],
        marker="o",
        linestyle="none",
        color=palette[m],
        markeredgecolor="black",
        markersize=12,
        label=abbr.get(m, m)
    )
    for m in models
]

# activation legend (shapes)
activation_handles = [
    Line2D([0], [0], marker=markers["ReLU"], linestyle="none", color="black",
           markersize=12, label="ReLU"),
    Line2D([0], [0], marker=markers["Tanh"], linestyle="none", color="black",
           markersize=12, label="Tanh"),
]

for ax in axes:
    ax.legend(
        handles=model_handles + activation_handles,
        title=None,
        loc="upper right",
        frameon=False,
        ncol=1,
        fontsize = 14
    )
plt.tight_layout(rect=(0, 0, 1, 1))
# Create the output folder if needed so savefig does not fail on a fresh checkout.
os.makedirs("figures_for_use_in_paper", exist_ok=True)
plt.savefig("figures_for_use_in_paper/friedman_RMSE_with_beta.pdf", bbox_inches="tight")
plt.show()

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
# --- prepare data ---
df = df_all_pm.copy()

# LaTeX legend labels for each prior family. Defined here as well so this cell
# does not silently depend on the previous plotting cell having been run.
abbr = {
    "Dirichlet tau": r"Dir - $\tau$",
    "Beta tau": r"Beta - $\tau$",
    "Dirichlet": r"Dir",
    "Beta": r"Beta",
}

# unify model names across activations (strip " tanh")
df["model_clean"] = df["model"].str.replace(" tanh", "", regex=False)

# mean / std of posterior-mean RMSE per (setting, N, model, activation)
summary = (
    df.groupby(["setting", "N", "model_clean", "activation"], as_index=False)["posterior_mean_rmse"]
      .agg(mean="mean", std="std")
)
# plotting order. Only settings that have results get a panel (df_all_pm holds
# only "Correlated" rows, so "Original" used to render an empty panel).
candidate_settings = ["Original", "Correlated"]
settings = [s for s in candidate_settings if (summary["setting"] == s).any()]
if not settings:
    raise ValueError(f"No posterior-mean RMSE results to plot for any of {candidate_settings}")
Ns = [50, 100, 200, 500]
models = ["Dirichlet tau", "Beta tau", "Dirichlet", "Beta"]

# visuals
markers = {"ReLU": "o", "Tanh": "^"}            # shapes
offsets = {"ReLU": -0.12, "Tanh": +0.12}        # side-by-side jitter on x
model_offsets = {
    "Dirichlet tau": -0.04,
    "Beta tau": -0.01,
    "Dirichlet": +0.01,
    "Beta": +0.04,
}
palette_list = plt.get_cmap("tab10").colors
palette = {m: palette_list[i] for i, m in enumerate(models)}

# map N to base x positions and add offsets for activation
xbase = {N: i for i, N in enumerate(Ns)}

# One 6-inch-wide panel per setting (12 x 5 when both settings are present, as before).
fig, axes = plt.subplots(1, len(settings), figsize=(6 * len(settings), 5), sharey=True, squeeze=False)
axes = axes[0]

for ax, setting in zip(axes, settings):
    sub = summary[summary["setting"] == setting]
    # plot each model+activation as markers only (no lines, no error bars)
    for m in models:
        for act in ["ReLU", "Tanh"]:
            g = sub[(sub["model_clean"] == m) & (sub["activation"] == act)]
            if g.empty:
                continue
            xs = [xbase[n] + offsets[act] + model_offsets[m] for n in g["N"]]

            ax.plot(
                xs, g["mean"],
                marker=markers[act],
                markersize=7,
                linestyle="none",
                color=palette[m],
                markeredgecolor="black",
            )

    ax.set_title(f"{setting}")
    ax.set_xticks(range(len(Ns)))
    ax.set_xticklabels(Ns)
    ax.set_xlabel("N")
    ax.set_ylabel("RMSE")
    ax.grid()

# --- legends ---
# colour = prior family
model_handles = [
    Line2D(
        [0], [0],
        marker="o",
        linestyle="none",
        color=palette[m],
        markeredgecolor="black",
        markersize=7,
        label=abbr.get(m, m)
    )
    for m in models
]

# activation legend (shapes)
activation_handles = [
    Line2D([0], [0], marker=markers["ReLU"], linestyle="none", color="black",
           markersize=7, label="ReLU"),
    Line2D([0], [0], marker=markers["Tanh"], linestyle="none", color="black",
           markersize=7, label="Tanh"),
]

for ax in axes:
    ax.legend(
        handles=model_handles + activation_handles,
        title=None,
        loc="upper right",
        frameon=False,
        ncol=1
    )
plt.tight_layout(rect=(0, 0, 1, 1))
plt.show()

In [43]:
import numpy as np
import pandas as pd
from utils.sparsity import forward_pass_relu, forward_pass_tanh, local_prune_weights

def compute_sparse_rmse_results_corr(seeds, models, all_fits, get_N_sigma, forward_pass,
                         sparsity=0.0, prune_fn=None):
    """
    Test-set RMSE of (optionally pruned) posterior draws on the correlated-Friedman data.

    For each seed, `get_N_sigma(seed)` returns (N, sigma), which selects both
    ``datasets/friedman_correlated/Friedman_N{N}_p10_sigma{sigma:.2f}_seed{seed}.npz``
    and the key into `all_fits`. Every posterior draw is pushed through
    `forward_pass(X_test, W1, b1, W2, b2)`. When `prune_fn` is given and
    `sparsity > 0`, `prune_fn([W1, W2], sparsity)` must return a list of masks;
    only the first-layer mask is applied (the W2 mask is intentionally unused).

    Args:
        seeds: dataset seeds to evaluate.
        models: model names looked up as `all_fits[dataset_key][model]['posterior']`.
        all_fits: nested mapping dataset_key -> model -> {'posterior': fit}; `fit`
            must provide `stan_variable(name)` for W_1, W_L, hidden_bias, output_bias.
        get_N_sigma: callable seed -> (N, sigma).
        forward_pass: callable producing predictions for one draw.
        sparsity: pruning fraction forwarded to `prune_fn`.
        prune_fn: optional mask-producing pruning function.

    Returns:
        (df_rmse, df_posterior_rmse): one row per posterior draw, and one row
        per (seed, model) with the RMSE of the posterior-mean prediction.
        Missing dataset files or fits are reported with "[SKIP]" and skipped.
    """
    results = []
    posterior_means = []

    for seed in seeds:
        N, sigma = get_N_sigma(seed)
        dataset_key = f'Friedman_N{N}_p10_sigma{sigma:.2f}_seed{seed}'
        path = f"datasets/friedman_correlated/{dataset_key}.npz"

        try:
            # Context manager closes the file handle np.load keeps open for .npz.
            with np.load(path) as data:
                X_test = data["X_test"]
                # Flatten targets once. A (N_test, 1) y_test would otherwise
                # broadcast against the (N_test,) predictions into an
                # (N_test, N_test) matrix and corrupt every per-draw RMSE.
                y_test = np.asarray(data["y_test"]).reshape(-1)
        except FileNotFoundError:
            print(f"[SKIP] File not found: {path}")
            continue

        for model in models:
            try:
                fit = all_fits[dataset_key][model]['posterior']
                W1_samples = fit.stan_variable("W_1")           # (S, P, H)
                W2_samples = fit.stan_variable("W_L")           # (S, H, O)
                b1_samples = fit.stan_variable("hidden_bias")   # (S, O, H)
                b2_samples = fit.stan_variable("output_bias")   # (S, O)
            except KeyError:
                print(f"[SKIP] Model or posterior not found: {dataset_key} -> {model}")
                continue

            S = W1_samples.shape[0]
            rmses = np.zeros(S)
            y_hats = np.zeros((S, y_test.shape[0]))

            for i in range(S):
                W1 = W1_samples[i]
                W2 = W2_samples[i]

                # Apply the pruning mask to the first layer only, if requested
                if prune_fn is not None and sparsity > 0.0:
                    masks = prune_fn([W1, W2], sparsity)
                    W1 = W1 * masks[0]

                y_hat = np.asarray(
                    forward_pass(X_test, W1, b1_samples[i][0], W2, b2_samples[i])
                ).reshape(-1)
                y_hats[i] = y_hat
                rmses[i] = np.sqrt(np.mean((y_hat - y_test) ** 2))

            posterior_mean = np.mean(y_hats, axis=0)
            posterior_mean_rmse = np.sqrt(np.mean((posterior_mean - y_test) ** 2))

            posterior_means.append({
                'seed': seed,
                'N': N,
                'sigma': sigma,
                'model': model,
                'sparsity': sparsity,
                'posterior_mean_rmse': posterior_mean_rmse
            })

            results.extend(
                {
                    'seed': seed,
                    'N': N,
                    'sigma': sigma,
                    'model': model,
                    'sparsity': sparsity,
                    'rmse': rmses[i]
                }
                for i in range(S)
            )

    df_rmse = pd.DataFrame(results)
    df_posterior_rmse = pd.DataFrame(posterior_means)

    return df_rmse, df_posterior_rmse


# Seeds evaluated in the sparsity sweep (full sets kept for reference).
seeds = [1]  # full set: [1, 2, 11]
seeds_correlated = [1]  # full set: [1, 6, 11]

def get_N_sigma(seed):
    """(N, sigma) for an original-Friedman seed: 1 -> 100, 2 -> 200, otherwise 500; sigma is 1.0."""
    sigma = 1.00
    if seed == 1:
        return 100, sigma
    if seed == 2:
        return 200, sigma
    return 500, sigma

def get_N_sigma_correlated(seed):
    """(N, sigma) for a correlated-Friedman seed: 1 -> 100, 6 -> 200, otherwise 500; sigma is 1.0."""
    sigma = 1.00
    if seed == 1:
        return 100, sigma
    if seed == 6:
        return 200, sigma
    return 500, sigma

In [34]:
sparsity_levels = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95]

# Results keyed by sparsity level. NOTE: these dicts rebind names that earlier
# held the unpruned RMSE DataFrames, so the earlier tagging/concat cell cannot be
# re-run after this one.
df_rmse_relu_correlated, df_posterior_rmse_relu_correlated = {}, {}
df_rmse_tanh_correlated, df_posterior_rmse_tanh_correlated = {}, {}

for sparsity in sparsity_levels:
    # ReLU first, then tanh, at every pruning level.
    relu_rmse, relu_post = compute_sparse_rmse_results_corr(
        seeds_correlated, model_names_relu, relu_fits_correlated,
        get_N_sigma_correlated, forward_pass_relu,
        sparsity=sparsity, prune_fn=local_prune_weights,
    )
    df_rmse_relu_correlated[sparsity] = relu_rmse
    df_posterior_rmse_relu_correlated[sparsity] = relu_post

    tanh_rmse, tanh_post = compute_sparse_rmse_results_corr(
        seeds_correlated, model_names_tanh, tanh_fits_correlated,
        get_N_sigma_correlated, forward_pass_tanh,
        sparsity=sparsity, prune_fn=local_prune_weights,
    )
    df_rmse_tanh_correlated[sparsity] = tanh_rmse
    df_posterior_rmse_tanh_correlated[sparsity] = tanh_post

In [35]:
import pandas as pd


def _stack_over_sparsity(frames_by_sparsity):
    """Concatenate a {sparsity: DataFrame} dict into one frame tagged with a 'sparsity' column."""
    tagged = [frame.assign(sparsity=level) for level, frame in frames_by_sparsity.items()]
    return pd.concat(tagged, ignore_index=True)


df_rmse_full_relu_correlated = _stack_over_sparsity(df_rmse_relu_correlated)
df_rmse_full_tanh_correlated = _stack_over_sparsity(df_rmse_tanh_correlated)

df_posterior_rmse_full_relu_correlated = _stack_over_sparsity(df_posterior_rmse_relu_correlated)
df_posterior_rmse_full_tanh_correlated = _stack_over_sparsity(df_posterior_rmse_tanh_correlated)


In [None]:
df_posterior_rmse_full_tanh_correlated.loc[df_posterior_rmse_full_tanh_correlated["sparsity"] == 0.9]

In [None]:
df_posterior_rmse_full_relu_correlated.loc[df_posterior_rmse_full_relu_correlated["sparsity"] == 0.9]

In [44]:
# =========================
# BLOCK 1 — FUNCTIONS ONLY
# =========================
import numpy as np

def forward_pass_relu(X, W1, b1, W2, b2):
    """One-hidden-layer ReLU net: max(0, X @ W1 + b1) @ W2 + b2 (biases broadcast as rows)."""
    hidden = np.maximum(0, X @ W1 + b1.reshape(1, -1))
    return hidden @ W2 + b2.reshape(1, -1)

def forward_pass_tanh(X, W1, b1, W2, b2):
    """One-hidden-layer tanh net: tanh(X @ W1 + b1) @ W2 + b2 (biases broadcast as rows)."""
    hidden = np.tanh(X @ W1 + b1.reshape(1, -1))
    return hidden @ W2 + b2.reshape(1, -1)

def mse_loss_and_grads(X, y, W1, b1, W2, b2, activation="relu"):
    """
    MSE regression loss + exact backprop grads for a 1-hidden-layer NN.

    Model: pred = act(X @ W1 + b1) @ W2 + b2; loss = mean((pred - y)**2),
    averaged over all N*K entries.

    Shapes assumed:
      X:  (N, D)
      y:  (N, K); a 1-D y of length N is treated as (N, 1)
      W1: (D, H), b1: H elements (reshaped to a row)
      W2: (H, K), b2: K elements (reshaped to a row)

    Args:
        activation: "relu" or "tanh".

    Returns:
        (loss, grads) where grads has "dW1" (D, H), "db1" (H,), "dW2" (H, K), "db2" (K,).

    Raises:
        ValueError: unknown `activation`, or `y` whose shape does not match the prediction.
    """
    pre = X @ W1 + b1.reshape(1, -1)

    if activation == "relu":
        hid = np.maximum(0, pre)
        dhid_dpre = (pre > 0).astype(pre.dtype)
    elif activation == "tanh":
        hid = np.tanh(pre)
        dhid_dpre = 1.0 - hid**2
    else:
        # Previously any other string silently fell through to tanh.
        raise ValueError(f"Unknown activation {activation!r}; expected 'relu' or 'tanh'")

    pred = hid @ W2 + b2.reshape(1, -1)

    y = np.asarray(y)
    if y.ndim == 1:
        # A 1-D target would silently broadcast (N, 1) - (N,) into (N, N).
        y = y.reshape(-1, 1)
    if y.shape != pred.shape:
        raise ValueError(f"y has shape {y.shape}, expected {pred.shape}")

    r = pred - y
    loss = np.mean(r**2)

    # d(loss)/d(pred): the mean runs over all N*K entries, so scale by r.size
    # (identical to the old 2/N factor when K == 1, correct for K > 1).
    d_pred = (2.0 / r.size) * r     # (N, K)

    dW2 = hid.T @ d_pred            # (H, K)
    db2 = np.sum(d_pred, axis=0)    # (K,)

    d_hid = d_pred @ W2.T           # (N, H)
    d_pre = d_hid * dhid_dpre       # (N, H)

    dW1 = X.T @ d_pre               # (D, H)
    db1 = np.sum(d_pre, axis=0)     # (H,)

    return loss, {"dW1": dW1, "db1": db1, "dW2": dW2, "db2": db2}

def pack_grads(grads):
    """Flatten a gradient dict into one vector, ordered dW1, db1, dW2, db2 (same order as pack_params)."""
    return np.concatenate([grads[key].ravel() for key in ("dW1", "db1", "dW2", "db2")])

def pack_params(W1, b1, W2, b2):
    """Flatten network parameters into one vector in the order W1, b1, W2, b2."""
    return np.concatenate([part.ravel() for part in (W1, b1, W2, b2)])

def unpack_params(theta, D, H, K):
    """Inverse of pack_params: split a flat vector into W1 (D, H), b1 (H,), W2 (H, K), b2 (K,)."""
    end_w1 = D * H
    end_b1 = end_w1 + H
    end_w2 = end_b1 + H * K
    W1 = theta[:end_w1].reshape(D, H)
    b1 = theta[end_w1:end_b1]
    W2 = theta[end_b1:end_w2].reshape(H, K)
    b2 = theta[end_w2:end_w2 + K]
    return W1, b1, W2, b2

def gradient_cosine_similarities(G):
    """Cosine similarity of each row of G with the mean row of G (1e-12 added to both norms)."""
    reference = G.mean(axis=0)
    row_norms = np.linalg.norm(G, axis=1) + 1e-12
    reference_norm = np.linalg.norm(reference) + 1e-12
    dots = G @ reference
    return dots / (row_norms * reference_norm)

def gradient_snr_over_batches(grad_list):
    """
    Signal-to-noise ratio of flattened gradients computed on different batches
    at fixed parameters: ||mean|| / sqrt(mean squared distance to the mean).

    grad_list: list of flattened gradients (equal length).
    """
    stacked = np.stack(grad_list, axis=0)
    centre = stacked.mean(axis=0)
    signal = np.linalg.norm(centre)
    spread = np.sqrt(np.mean(((stacked - centre) ** 2).sum(axis=1))) + 1e-12
    return signal / spread

def directional_grad_lipschitz_fd(X, y, W1, b1, W2, b2, activation="relu", eps=1e-3, rng=None):
    """
    Forward-difference estimate of ||grad(w + eps*v) - grad(w)|| / eps along one
    random unit direction v in flattened-parameter space; a cheap local
    curvature proxy.

    Args:
        rng: numpy Generator used to draw v; np.random.default_rng(0) when None.
        eps: step length along v.

    Returns:
        The scalar ratio above.
    """
    generator = np.random.default_rng(0) if rng is None else rng

    D, H = W1.shape
    K = W2.shape[1]
    theta = pack_params(W1, b1, W2, b2)

    direction = generator.normal(size=theta.shape)
    direction /= (np.linalg.norm(direction) + 1e-12)

    def flat_grad(w1, bias1, w2, bias2):
        # MSE gradient at the given parameters, flattened in pack_params order.
        return pack_grads(mse_loss_and_grads(X, y, w1, bias1, w2, bias2, activation=activation)[1])

    grad_here = flat_grad(W1, b1, W2, b2)
    grad_stepped = flat_grad(*unpack_params(theta + eps * direction, D, H, K))

    return np.linalg.norm(grad_stepped - grad_here) / eps

def make_minibatches(X, y, batch_size=64, n_batches=10, seed=0):
    """
    Minimal minibatch sampler for the gradient-SNR section.

    Draws `n_batches` index sets, each sampled without replacement from
    range(N) with N = X.shape[0] (different batches may overlap), and returns
    the matching rows of X and y. `yb` keeps whatever trailing shape `y` has,
    so pass a (N, K) array to get (B, K) batches.

    `batch_size` is capped at N: sampling without replacement cannot take more
    than N rows, and previously this raised a ValueError on small datasets.

    Returns:
        list of (Xb, yb) tuples.
    """
    rng = np.random.default_rng(seed)
    N = X.shape[0]
    B = min(batch_size, N)
    batches = []
    for _ in range(n_batches):
        ids = rng.choice(N, size=B, replace=False)
        batches.append((X[ids], y[ids]))
    return batches


In [None]:
# =========================
# BLOCK 2 — EXECUTION ONLY
# =========================

# --- A) DATA LOADING + BATCH SETUP ---
# --- A) DATA LOADING + BATCH SETUP ---
N_EVAL, SIGMA_EVAL, SEED_EVAL = 500, 1, 11
dataset_key = f"Friedman_N{N_EVAL}_p10_sigma{SIGMA_EVAL:.2f}_seed{SEED_EVAL}"
path = f"datasets/friedman/{dataset_key}.npz"
data = np.load(path)

X_train, X_test = data["X_train"], data["X_test"]
y_train, y_test = data["y_train"], data["y_test"]

# The whole test set acts as the "batch"; targets are forced to (N, 1) once.
X_batch = X_test
y_batch = y_test.reshape(-1, 1)

# --- B) MODEL + POSTERIOR DRAWS ---
model_name = "Regularized Horseshoe tanh"
activation = "tanh"   # "relu" or "tanh"

# NOTE(review): `posterior_N100_fits` is not defined anywhere in this notebook, so
# this cell depends on kernel state from elsewhere. The name suggests N=100 fits,
# while the data above is the N=500 / seed-11 set; confirm this is intended.
post = posterior_N100_fits[model_name]["posterior"]
W1_draws, b1_draws, W2_draws, b2_draws = (
    post.stan_variable(var) for var in ("W_1", "hidden_bias", "W_L", "output_bias")
)

S, D, H = W1_draws.shape[:3]
K = W2_draws.shape[2]

# --- C) SUBSAMPLE DRAWS FOR DIAGNOSTICS ---
S_use = min(S, 300)   # cap on the number of draws examined
idx = np.arange(S_use)

# --- D) POSTERIOR GRADIENT STABILITY (across draws) ---
# Full-batch loss and flattened gradient at each retained draw.
per_draw = [
    mse_loss_and_grads(
        X_batch, y_batch,
        W1_draws[s], b1_draws[s], W2_draws[s], b2_draws[s],
        activation=activation,
    )
    for s in idx
]
losses = [loss_s for loss_s, _ in per_draw]
G = np.stack([pack_grads(grads_s) for _, grads_s in per_draw], axis=0)
g_norms = np.linalg.norm(G, axis=1)
g_cos = gradient_cosine_similarities(G)

quantile_levels = [0.05, 0.5, 0.95]
print("Posterior gradient norms (5/50/95%):", np.quantile(g_norms, quantile_levels))
print("Posterior grad cosine vs mean (5/50/95%):", np.quantile(g_cos, quantile_levels))
print("Posterior losses (5/50/95%):", np.quantile(np.array(losses), quantile_levels))

# --- E) CURVATURE PROXY (directional FD gradient Lipschitz) ---
# A single Generator, consumed draw by draw, keeps the random directions reproducible.
rng = np.random.default_rng(1)
S_curv = min(S_use, 100)
Ldir = np.array([
    directional_grad_lipschitz_fd(
        X_batch, y_batch,
        W1_draws[s], b1_draws[s], W2_draws[s], b2_draws[s],
        activation=activation, eps=1e-3, rng=rng,
    )
    for s in idx[:S_curv]
])
print("Directional grad-Lipschitz proxy (5/50/95%):", np.quantile(Ldir, quantile_levels))

# --- F) (OPTIONAL) GRADIENT SNR ACROSS MINIBATCHES FOR ONE DRAW ---
DO_SNR = True  # set False to skip this section cleanly

if DO_SNR:
    # Minibatches come from the training set with (B, 1) targets.
    y_train_2d = y_train.reshape(-1, 1)
    batches = make_minibatches(X_train, y_train_2d, batch_size=64, n_batches=10, seed=0)

    s0 = 0  # the single posterior draw at which batch gradients are compared
    grad_list = [
        pack_grads(
            mse_loss_and_grads(
                Xb, yb,
                W1_draws[s0], b1_draws[s0], W2_draws[s0], b2_draws[s0],
                activation=activation,
            )[1]
        )
        for Xb, yb in batches
    ]
    snr = gradient_snr_over_batches(grad_list)
    print("Gradient SNR across minibatches at draw s0:", snr)


In [None]:
# =========================
# BLOCK 2 — EXECUTION ONLY
# =========================

# --- A) DATA LOADING + BATCH SETUP ---
# --- A) DATA LOADING + BATCH SETUP ---
N_EVAL, SIGMA_EVAL, SEED_EVAL = 500, 1, 11
dataset_key = f"Friedman_N{N_EVAL}_p10_sigma{SIGMA_EVAL:.2f}_seed{SEED_EVAL}"
path = f"datasets/friedman/{dataset_key}.npz"
data = np.load(path)

X_train, X_test = data["X_train"], data["X_test"]
y_train, y_test = data["y_train"], data["y_test"]

# The whole test set acts as the "batch"; targets are forced to (N, 1) once.
X_batch = X_test
y_batch = y_test.reshape(-1, 1)

# --- B) MODEL + POSTERIOR DRAWS ---
model_name = "Dirichlet Horseshoe tanh"
activation = "tanh"   # "relu" or "tanh"

# NOTE(review): `posterior_N100_fits` is not defined anywhere in this notebook, so
# this cell depends on kernel state from elsewhere. The name suggests N=100 fits,
# while the data above is the N=500 / seed-11 set; confirm this is intended.
post = posterior_N100_fits[model_name]["posterior"]
W1_draws, b1_draws, W2_draws, b2_draws = (
    post.stan_variable(var) for var in ("W_1", "hidden_bias", "W_L", "output_bias")
)

S, D, H = W1_draws.shape[:3]
K = W2_draws.shape[2]

# --- C) SUBSAMPLE DRAWS FOR DIAGNOSTICS ---
S_use = min(S, 300)   # cap on the number of draws examined
idx = np.arange(S_use)

# --- D) POSTERIOR GRADIENT STABILITY (across draws) ---
# Full-batch loss and flattened gradient at each retained draw.
per_draw = [
    mse_loss_and_grads(
        X_batch, y_batch,
        W1_draws[s], b1_draws[s], W2_draws[s], b2_draws[s],
        activation=activation,
    )
    for s in idx
]
losses = [loss_s for loss_s, _ in per_draw]
G = np.stack([pack_grads(grads_s) for _, grads_s in per_draw], axis=0)
g_norms = np.linalg.norm(G, axis=1)
g_cos = gradient_cosine_similarities(G)

quantile_levels = [0.05, 0.5, 0.95]
print("Posterior gradient norms (5/50/95%):", np.quantile(g_norms, quantile_levels))
print("Posterior grad cosine vs mean (5/50/95%):", np.quantile(g_cos, quantile_levels))
print("Posterior losses (5/50/95%):", np.quantile(np.array(losses), quantile_levels))

# --- E) CURVATURE PROXY (directional FD gradient Lipschitz) ---
# A single Generator, consumed draw by draw, keeps the random directions reproducible.
rng = np.random.default_rng(1)
S_curv = min(S_use, 100)
Ldir = np.array([
    directional_grad_lipschitz_fd(
        X_batch, y_batch,
        W1_draws[s], b1_draws[s], W2_draws[s], b2_draws[s],
        activation=activation, eps=1e-3, rng=rng,
    )
    for s in idx[:S_curv]
])
print("Directional grad-Lipschitz proxy (5/50/95%):", np.quantile(Ldir, quantile_levels))

# --- F) (OPTIONAL) GRADIENT SNR ACROSS MINIBATCHES FOR ONE DRAW ---
DO_SNR = True  # set False to skip this section cleanly

if DO_SNR:
    # Minibatches come from the training set with (B, 1) targets.
    y_train_2d = y_train.reshape(-1, 1)
    batches = make_minibatches(X_train, y_train_2d, batch_size=64, n_batches=10, seed=0)

    s0 = 0  # the single posterior draw at which batch gradients are compared
    grad_list = [
        pack_grads(
            mse_loss_and_grads(
                Xb, yb,
                W1_draws[s0], b1_draws[s0], W2_draws[s0], b2_draws[s0],
                activation=activation,
            )[1]
        )
        for Xb, yb in batches
    ]
    snr = gradient_snr_over_batches(grad_list)
    print("Gradient SNR across minibatches at draw s0:", snr)


In [None]:
# =========================
# BLOCK 2 — EXECUTION ONLY
# =========================

# --- A) DATA LOADING + BATCH SETUP ---
# --- A) DATA LOADING + BATCH SETUP ---
N_EVAL, SIGMA_EVAL, SEED_EVAL = 500, 1, 11
dataset_key = f"Friedman_N{N_EVAL}_p10_sigma{SIGMA_EVAL:.2f}_seed{SEED_EVAL}"
path = f"datasets/friedman/{dataset_key}.npz"
data = np.load(path)

X_train, X_test = data["X_train"], data["X_test"]
y_train, y_test = data["y_train"], data["y_test"]

# The whole test set acts as the "batch"; targets are forced to (N, 1) once.
X_batch = X_test
y_batch = y_test.reshape(-1, 1)

# --- B) MODEL + POSTERIOR DRAWS ---
model_name = "Beta Horseshoe tanh"
activation = "tanh"   # "relu" or "tanh"

# NOTE(review): `posterior_N100_fits` is not defined anywhere in this notebook, so
# this cell depends on kernel state from elsewhere. The name suggests N=100 fits,
# while the data above is the N=500 / seed-11 set; confirm this is intended.
post = posterior_N100_fits[model_name]["posterior"]
W1_draws, b1_draws, W2_draws, b2_draws = (
    post.stan_variable(var) for var in ("W_1", "hidden_bias", "W_L", "output_bias")
)

S, D, H = W1_draws.shape[:3]
K = W2_draws.shape[2]

# --- C) SUBSAMPLE DRAWS FOR DIAGNOSTICS ---
S_use = min(S, 300)   # cap on the number of draws examined
idx = np.arange(S_use)

# --- D) POSTERIOR GRADIENT STABILITY (across draws) ---
# Full-batch loss and flattened gradient at each retained draw.
per_draw = [
    mse_loss_and_grads(
        X_batch, y_batch,
        W1_draws[s], b1_draws[s], W2_draws[s], b2_draws[s],
        activation=activation,
    )
    for s in idx
]
losses = [loss_s for loss_s, _ in per_draw]
G = np.stack([pack_grads(grads_s) for _, grads_s in per_draw], axis=0)
g_norms = np.linalg.norm(G, axis=1)
g_cos = gradient_cosine_similarities(G)

quantile_levels = [0.05, 0.5, 0.95]
print("Posterior gradient norms (5/50/95%):", np.quantile(g_norms, quantile_levels))
print("Posterior grad cosine vs mean (5/50/95%):", np.quantile(g_cos, quantile_levels))
print("Posterior losses (5/50/95%):", np.quantile(np.array(losses), quantile_levels))

# --- E) CURVATURE PROXY (directional FD gradient Lipschitz) ---
# A single Generator, consumed draw by draw, keeps the random directions reproducible.
rng = np.random.default_rng(1)
S_curv = min(S_use, 100)
Ldir = np.array([
    directional_grad_lipschitz_fd(
        X_batch, y_batch,
        W1_draws[s], b1_draws[s], W2_draws[s], b2_draws[s],
        activation=activation, eps=1e-3, rng=rng,
    )
    for s in idx[:S_curv]
])
print("Directional grad-Lipschitz proxy (5/50/95%):", np.quantile(Ldir, quantile_levels))

# --- F) (OPTIONAL) GRADIENT SNR ACROSS MINIBATCHES FOR ONE DRAW ---
DO_SNR = True  # set False to skip this section cleanly

if DO_SNR:
    # Minibatches come from the training set with (B, 1) targets.
    y_train_2d = y_train.reshape(-1, 1)
    batches = make_minibatches(X_train, y_train_2d, batch_size=64, n_batches=10, seed=0)

    s0 = 0  # the single posterior draw at which batch gradients are compared
    grad_list = [
        pack_grads(
            mse_loss_and_grads(
                Xb, yb,
                W1_draws[s0], b1_draws[s0], W2_draws[s0], b2_draws[s0],
                activation=activation,
            )[1]
        )
        for Xb, yb in batches
    ]
    snr = gradient_snr_over_batches(grad_list)
    print("Gradient SNR across minibatches at draw s0:", snr)


## Regularized Horseshoe (RHS), tanh
Posterior gradient norms (5/50/95%): [0.16470399 0.25879643 0.51681143]

Posterior grad cosine vs mean (5/50/95%): [-0.00202803  0.53864411  0.72677177]

Posterior losses (5/50/95%): [0.03460988 0.04023591 0.04680309]

Directional grad-Lipschitz proxy (5/50/95%): [ 1.53809331  6.23787591 18.88394891]

Gradient SNR across minibatches at draw s0: 0.2733637598423519

## Dirichlet Horseshoe, tanh
Posterior gradient norms (5/50/95%): [0.16137586 0.24824739 0.47044927]

Posterior grad cosine vs mean (5/50/95%): [0.05011153 0.43383093 0.65442166]

Posterior losses (5/50/95%): [0.03467333 0.03932857 0.04483365]

Directional grad-Lipschitz proxy (5/50/95%): [ 1.07342317  6.41491771 16.73721671]

Gradient SNR across minibatches at draw s0: 0.36852927989323614

## Beta Horseshoe, tanh
Posterior gradient norms (5/50/95%): [0.17097207 0.25707368 0.49875474]

Posterior grad cosine vs mean (5/50/95%): [-0.07632456  0.56714935  0.7694549 ]

Posterior losses (5/50/95%): [0.03445648 0.03909852 0.04472353]

Directional grad-Lipschitz proxy (5/50/95%): [ 1.52483842  6.88174365 21.6660652 ]

Gradient SNR across minibatches at draw s0: 0.40063152311073363


In [None]:
#