In [1]:
import numpy as np

def make_latent_data_sec54(n, p, d=20, r_theta=1.0, sigma_xi=0.0, rng=None):
    """
    Simulate one dataset from the latent-space model of Section 5.4
    (Hastie–Montanari–Rosset–Tibshirani).

    Generative model:
      X = Z W^T + U,   y = Z θ + ξ
      z_i ~ N(0, I_d), u_ij ~ N(0, 1), ξ_i ~ N(0, σ_ξ^2)
    where every row w_j of W is rescaled to unit Euclidean norm.

    Equivalent linear-model quantities (used for population risk):
      Σ = I_p + W W^T,   β = W (I_d + W^T W)^{-1} θ.

    Parameters
    ----------
    n : int        number of observations (rows of X).
    p : int        number of observed features (columns of X).
    d : int        latent dimension.
    r_theta : float   Euclidean norm imposed on θ.
    sigma_xi : float  standard deviation of the label noise ξ.
    rng : numpy.random.Generator or None
        Source of randomness; a fresh, unseeded generator is created when None.

    Returns
    -------
    X (n×p), y (n,), W (p×d), theta (d,), beta_true (p,), Sigma (p×p)
    """
    if rng is None:
        rng = np.random.default_rng()

    # Small constant that keeps the normalisations safe against a zero norm.
    tiny = 1e-12

    # Loadings: Gaussian rows projected onto the unit sphere.
    loadings = rng.normal(size=(p, d))
    row_norms = np.linalg.norm(loadings, axis=1, keepdims=True)
    W = loadings / (row_norms + tiny)

    # Random draws, in this order: latent factors, feature noise, label noise.
    Z = rng.normal(size=(n, d))
    U = rng.normal(size=(n, p))
    xi = rng.normal(scale=sigma_xi, size=n)

    # Signal direction rescaled so that ||θ|| = r_theta.
    direction = rng.normal(size=d)
    theta = direction * (r_theta / (np.linalg.norm(direction) + tiny))

    # Observed design and response.
    X = Z @ W.T + U
    y = Z @ theta + xi

    # Population covariance and the implied regression coefficients.
    Sigma = np.eye(p) + W @ W.T
    latent_gram = np.eye(d) + W.T @ W
    beta_true = W @ np.linalg.solve(latent_gram, theta)

    return X, y, W, theta, beta_true, Sigma


In [2]:
import numpy as np

def fit_min_norm(X, y):
    """
    Ridgeless least squares.

    Returns the minimum-ℓ2-norm solution β̂ = X^+ y, where X^+ is the
    Moore–Penrose pseudo-inverse (the limit of ridge regression as the
    penalty goes to zero).
    """
    X_pinv = np.linalg.pinv(X)
    return X_pinv @ y


In [24]:
import numpy as np

def sample_hidden_features_tanh_with_W1(
    X,
    rng,
    H,
    p_0=3,
    a=2.0,
    b=2.0,
    alpha_scale=0.5,
):
    """
    Draw one first-layer weight matrix W1 from the sparsity prior and return
    the tanh hidden features it induces.

    Prior on W1[j, h] (input j, hidden unit h):
      τ   = |Cauchy(0, 1)| · τ0,  τ0 = p_0 / (P − p_0)      (one global scale)
      c_h² = 1 / Gamma(shape=a, scale=1/b)                   (one slab per unit)
      λ_hj = |Cauchy(0, 1)|                                  (local scales)
      φ_h  ~ Dirichlet(alpha_scale · 1_P)                    (per-unit allocation)
      λ̃_hj² = c_h² λ_hj² / (c_h² + τ² λ_hj²)   (floored at 1e-12)
      W1[j, h] ~ N(0, τ² · λ̃_hj² · φ_hj)

    Parameters
    ----------
    X : array of shape (n, P)
    rng : numpy.random.Generator
    H : int, number of hidden units
    p_0 : prior guess of the number of relevant inputs; must satisfy p_0 < P
    a, b : shape and rate of the Gamma draw whose reciprocal is c²
    alpha_scale : common Dirichlet concentration

    Returns
    -------
    (Z, W1) with Z = tanh(X @ W1) of shape (n, H) and W1 of shape (P, H).

    Raises
    ------
    ValueError
        If P <= p_0, because τ0 = p_0 / (P − p_0) would be undefined (P == p_0)
        or negative (P < p_0).
    """
    _, P = X.shape
    if P <= p_0:
        raise ValueError(
            f"p_0 must be smaller than the number of input features "
            f"(got p_0={p_0}, P={P}); tau0 = p_0 / (P - p_0) is not a valid scale."
        )
    tau0 = p_0 / (P - p_0)
    alpha = np.full(P, alpha_scale)

    # The order of the draws below is kept fixed so that a seeded generator
    # yields the same W1 as before for every valid (P, p_0).
    tau = np.abs(rng.standard_cauchy()) * tau0
    c_sq = 1.0 / rng.gamma(shape=a, scale=1.0 / b, size=H)
    lambda_data = np.abs(rng.standard_cauchy(size=(H, P)))
    phi_data = rng.dirichlet(alpha, size=H)

    # Regularised local variance λ̃² (shape (H, P)); the floor avoids exact zeros.
    lam_sq = lambda_data**2
    slab = c_sq[:, None]
    lambda_tilde_sq = (slab * lam_sq) / (slab + lam_sq * (tau**2))
    lambda_tilde_sq = np.maximum(lambda_tilde_sq, 1e-12)

    # Scale standard-normal draws by the per-weight standard deviation, (P, H).
    W1_raw = rng.normal(0.0, 1.0, size=(P, H))
    stddev = tau * np.sqrt(lambda_tilde_sq.T) * np.sqrt(phi_data.T)
    W1 = W1_raw * stddev

    Z = np.tanh(X @ W1)
    return Z, W1

def plot_risk_curve_hidden_units(
    n=400,
    gammas=(0.7, 0.9, 1.2, 1.5, 2, 3, 5, 8, 12, 20),
    d=20,
    r_theta=1.0,
    sigma_xi=0.0,
    reps=50,
    risk_mc_samples=1000,
    seed=0,
    # prior hyperparams
    p_0=3,
    a=2.0,
    b=2.0,
    alpha_scale=0.5,
    # schedule H as a function of (p, n)
    H_of_p=lambda p, n: p,   # <- default ties H to p (so H grows with γ)
):
    """
    Monte Carlo risk curve for min-norm regression on random tanh hidden units.

    For each aspect ratio γ in `gammas` (with p = round(γ·n), at least 1) and
    each of `reps` repetitions:
      1. draw a §5.4 latent dataset (X, y);
      2. draw one hidden map W1 from the prior and form Z = tanh(X @ W1) with
         H = H_of_p(p, n) hidden units (at least 1);
      3. fit the minimum-norm readout on Z;
      4. estimate population risk as the mean squared difference between the
         readout's prediction and x^T β over `risk_mc_samples` fresh inputs,
         using the SAME W1 as in training.

    Fresh inputs are generated from the model itself, x = W z + u with
    z ~ N(0, I_d) and u ~ N(0, I_p). This has covariance I_p + W W^T = Σ, so
    it is the same N(0, Σ) population without factorising the p×p covariance
    in every repetition.

    Despite its name this function draws no figure.

    Returns
    -------
    (G, M, S) : arrays over `gammas` holding γ, the mean risk across reps and
    the sample standard deviation (ddof=1) across reps.
    """
    rng = np.random.default_rng(seed)
    G, M, S = [], [], []

    for gamma in gammas:
        p = max(1, int(round(gamma * n)))
        H = max(1, int(H_of_p(p, n)))
        risks = []

        for _ in range(reps):
            X, y, W, _theta, beta_true, _Sigma = make_latent_data_sec54(
                n=n, p=p, d=d, r_theta=r_theta, sigma_xi=sigma_xi, rng=rng
            )

            # One sampled hidden map W1 per repetition; reused for the population risk.
            Z, W1 = sample_hidden_features_tanh_with_W1(
                X, rng, H=H, p_0=p_0, a=a, b=b, alpha_scale=alpha_scale
            )

            # Min-norm readout on the hidden units (flatten a possible (H, 1) column).
            w_hat = fit_min_norm(Z, y)
            if w_hat.ndim > 1 and w_hat.shape[1] == 1:
                w_hat = w_hat.ravel()

            # Population inputs x = W z + u  ~  N(0, I_p + W W^T).
            Z_lat = rng.normal(size=(risk_mc_samples, d))
            X_pop = Z_lat @ W.T + rng.normal(size=(risk_mc_samples, p))

            y_true_pop = X_pop @ beta_true
            y_pred_pop = np.tanh(X_pop @ W1) @ w_hat
            risks.append(float(np.mean((y_pred_pop - y_true_pop) ** 2)))

        G.append(gamma)
        M.append(np.mean(risks))
        S.append(np.std(risks, ddof=1))

    return np.array(G), np.array(M), np.array(S)

# Hidden-unit risk curve in the noiseless setting; the hidden width H is tied
# to the input dimension p, so it grows with γ.
G_h, M_h, S_h = plot_risk_curve_hidden_units(
    n=100,
    gammas=[0.3, 0.7, 0.9, 1.5, 2, 3, 5, 8, 10],
    d=20,
    r_theta=1.0,
    sigma_xi=0.0,
    reps=50,
    risk_mc_samples=1000,
    seed=123,
    H_of_p=lambda p, n: p,
)


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Normal-approximation 95% band for the mean risk; `reps` must equal the
# number of repetitions used when G_h, M_h, S_h were computed.
reps = 50
ci_h = 1.96 * S_h / np.sqrt(reps)

plt.figure(figsize=(6.4, 4.4))
plt.plot(G_h, M_h, marker="o", linewidth=2)
plt.fill_between(G_h, M_h - ci_h, M_h + ci_h, alpha=0.2)
plt.axvline(1.0, linestyle="--", linewidth=1)  # interpolation threshold γ = 1
plt.xlabel(r"Aspect ratio  $\gamma = p/n$")
# This curve is a Monte Carlo mean squared error of the hidden-unit predictor
# against x^T β, not the closed-form quadratic form.
plt.ylabel(r"Population risk  (Monte Carlo MSE against $x^\top \beta$)")
# Parameters in the title match the run that produced G_h, M_h, S_h (n=100).
plt.title(
    "Latent space (§5.4) — min-norm risk on tanh hidden units vs γ\n"
    r"(n=100, d=20, r=1, $\sigma_\xi$=0)"
)
plt.tight_layout()
plt.show()

In [None]:
# Mean Monte Carlo risk per γ returned by plot_risk_curve_hidden_units.
M_h

In [None]:
# Mean risk per γ from plot_risk_curve_sec54.
# NOTE(review): in top-to-bottom order `M` is only assigned by the
# plot_risk_curve_sec54 call in a later cell, so this cell fails on a fresh
# kernel unless that cell has been run first.
M

In [29]:

def plot_risk_curve_sec54(
    n=400,
    gammas=(0.7, 0.9, 1.2, 1.5, 2, 3, 5, 8, 12, 20),
    d=20,
    r_theta=1.0,          # "r = 1" in the captions
    sigma_xi=0.0,         # use 0 for Fig. 5 behavior; try 0, 0.25, 0.5 like Fig. 6
    reps=50,
    seed=0
):
    """
    Replicates the latent-space risk curve of §5.4 (Figs. 5–6).

    For each γ = p/n in `gammas` (with p = round(γ·n), at least 1), simulate
    (X, y) `reps` times, fit the min-norm estimator β̂ and evaluate the
    population risk R_X = (β̂−β)^T Σ (β̂−β).

    Expectation from §5.4: spike near γ≈1 and then *monotone decrease* for γ>1,
    reaching a global minimum as γ→∞ when β aligns with top eigenspace of Σ.

    Despite its name this function draws no figure; it only returns the numbers
    needed for one.

    Returns
    -------
    (G, M, S) : arrays over `gammas` holding γ, the mean risk across reps and
    the sample standard deviation (ddof=1) across reps. A 95% band for the
    mean can be formed by the caller as 1.96 * S / sqrt(reps).
    """
    rng = np.random.default_rng(seed)
    G, M, S = [], [], []

    for gamma in gammas:
        p = max(1, int(round(gamma * n)))
        risks = []

        for _ in range(reps):
            X, y, _W, _theta, beta_true, Sigma = make_latent_data_sec54(
                n=n, p=p, d=d, r_theta=r_theta, sigma_xi=sigma_xi, rng=rng
            )
            beta_hat = fit_min_norm(X, y)

            # Quadratic-form risk in the population covariance.
            diff = beta_hat - beta_true
            risks.append(float(diff @ (Sigma @ diff)))

        G.append(gamma)
        M.append(np.mean(risks))
        S.append(np.std(risks, ddof=1))

    return np.array(G), np.array(M), np.array(S)

# Example run in the noiseless setting (sigma_xi = 0).
G, M, S = plot_risk_curve_sec54(
    n=100,
    gammas=[0.3, 0.7, 0.9, 1.5, 2, 3, 5, 8, 10],
    d=20,
    r_theta=1.0,
    sigma_xi=0.0,
    reps=50,
)
# For a Fig. 6-style panel with label noise, rerun with sigma_xi in {0.0, 0.25, 0.5}.


In [None]:
# Normal-approximation 95% band for the mean risk; `reps` must equal the
# number of repetitions used when G, M, S were computed.
reps = 50
ci = 1.96 * S / np.sqrt(reps)

plt.figure(figsize=(6.4, 4.4))
plt.plot(G, M, marker="o", linewidth=2)
plt.fill_between(G, M - ci, M + ci, alpha=0.2)
plt.axvline(1.0, linestyle="--", linewidth=1)  # interpolation threshold γ = 1
plt.xlabel(r"Aspect ratio  $\gamma = p/n$")
plt.ylabel(r"Population risk  $R_X=(\hat\beta-\beta)^\top \Sigma (\hat\beta-\beta)$")
# Parameters in the title match the run that produced G, M, S (n=100).
plt.title(r"Latent space (§5.4) — min-norm risk vs γ  (n=100, d=20, r=1, $\sigma_\xi$=0)")
plt.tight_layout()
plt.show()

In [None]:
import sys, os; sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__) if '__file__' in globals() else os.getcwd(), '..')))
#import os; os.chdir(os.path.dirname(os.getcwd()))
from utils.model_loader import get_model_fits

In [5]:
# One noiseless §5.4 dataset with two observed features; passing rng=None
# means an unseeded generator is used, so the draw differs between runs.
X, y, W, theta, beta_true, Sigma = make_latent_data_sec54(
    n=400,
    p=2,
    d=20,
    r_theta=1,
    sigma_xi=0,
    rng=None,
)

In [7]:
from sklearn.model_selection import train_test_split
# Hold out 20% of the rows; no random_state is passed, so the split is not
# reproducible across runs.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
)

In [None]:
# Min-norm readout on a hidden-feature design, followed by the quadratic-form
# risk (β̂−β)^T Σ (β̂−β).
# NOTE(review): `H_output` is not assigned in any cell visible in this
# notebook — confirm where it is supposed to come from.
# NOTE(review): beta_hat has one entry per column of H_output, whereas
# beta_true (length p) and Sigma (p×p) come from make_latent_data_sec54 and
# live in input space — verify the shapes agree before trusting this number.
beta_hat = fit_min_norm(H_output, y_train)
diff = beta_hat - beta_true
float(diff @ (Sigma @ diff))

In [None]:
import argparse
from cmdstanpy import set_cmdstan_path

def run_regression_model_local(model_name, config_name, X_train, X_test, y_train, y_test, args):
    """
    Compile a Stan model, sample from it with task "prior", and copy the chain
    CSV files plus run metadata into the output directory.

    Parameters
    ----------
    model_name : str
        Base name of the Stan file under ``bnn_prior_models_double_descent/``.
    config_name : str
        Passed through to ``save_metadata``.
    X_train, X_test, y_train : arrays
        Forwarded to ``make_stan_data``.
    y_test : array
        Accepted for signature symmetry; not used here.
    args : argparse.Namespace-like
        Must provide ``samples``, ``burnin_samples``, ``data_config`` and
        ``model_output_dir``; ``standardize`` is read only when
        ``data_config == "uci"``. An optional ``seed`` attribute seeds both
        NumPy's global RNG and the Stan sampler. ``args.num_classes`` is set
        to 1 on the passed object (side effect on the caller's namespace).

    Side effects
    ------------
    Creates the output directory, writes metadata and ``chain_<i>.csv`` files
    into it, and prints the destination.
    """
    from cmdstanpy import CmdStanModel
    from utils.stan_data_generator import make_stan_data
    from utils.io_helpers import save_metadata
    import os, shutil
    import numpy as np

    # Seed NumPy (for any randomness while building the Stan data) and keep
    # the value so the sampler below can use it too.
    seed = getattr(args, 'seed', None)
    if seed is not None:
        np.random.seed(seed)

    task = "prior"
    args.num_classes = 1  # Still needed

    stan_data = make_stan_data(model_name, task, X_train, y_train, X_test, args)

    model_path = f"bnn_prior_models_double_descent/{model_name}.stan"
    model = CmdStanModel(stan_file=model_path, force_compile=True)

    fit = model.sample(
        data=stan_data,
        chains=4,
        iter_sampling=args.samples,
        iter_warmup=args.burnin_samples,
        adapt_delta=0.8,
        parallel_chains=4,
        # Seeding NumPy alone does not make Stan's draws reproducible; the
        # sampler needs the seed itself (None keeps the previous behaviour).
        seed=seed,
        show_console=False,
        #max_treedepth = 12,
    )

    # Only standardized UCI runs get a dedicated sub-folder; every other
    # configuration writes straight into model_output_dir.
    output_dir = os.fspath(args.model_output_dir)
    if args.data_config == "uci" and args.standardize:
        output_dir = os.path.join(output_dir, "standardized")
    os.makedirs(output_dir, exist_ok=True)
    save_metadata(output_dir, args, config_name)

    # Copy CmdStan's per-chain CSV output under stable names.
    for i, path in enumerate(fit.runset.csv_files, start=1):
        shutil.copy(path, os.path.join(output_dir, f"chain_{i}.csv"))

    print(f"[✓] Saved results to: {output_dir}")


# Set the CmdStan path.
# NOTE(review): this is a hard-coded absolute path to one user's home
# directory; the cell will not run unchanged on another machine.
set_cmdstan_path("/Users/augustarnstad/.cmdstan/cmdstan-2.36.0")
# Sample the "gaussian" model on the train/test split with the settings below.
run_regression_model_local(
    model_name="gaussian",
    config_name="latent_model",
    X_train=X_train,
    X_test=X_test,
    y_train=y_train,
    y_test=y_test,
    args=argparse.Namespace(
        N=X_train.shape[0],
        p=X_train.shape[1],
        sigma=None,
        data="",
        standardize=False,
        test_shift=None,
        model="gaussian",
        H=800,
        L=1,
        config="Latent",
        seed=1,
        data_config="realworld",
        model_output_dir="results/ridgeless/double_descent/priors/gaussian_tanh",
        burnin_samples=1,
        samples=500,
    )
)

In [None]:
# Location of the saved fits and the models to load from it.
results_dir_tanh = "results/ridgeless/double_descent/priors"
# Other prior families that can be added to the list:
#   "Regularized Horseshoe tanh", "Dirichlet Horseshoe tanh", "Dirichlet Student T tanh"
model_names_tanh = ["Gaussian tanh"]

tanh_fit = get_model_fits(
    models=model_names_tanh,
    results_dir=results_dir_tanh,
    config="",
    include_prior=False,
)


In [None]:
post_gauss = tanh_fit["Gaussian tanh"]["posterior"]
# post_RHS = tanh_fit['Regularized Horseshoe tanh']['posterior']
# post_DHS = tanh_fit['Dirichlet Horseshoe tanh']['posterior']
# post_DST = tanh_fit['Dirichlet Student T tanh']['posterior']

# All draws of the first-layer weights and hidden biases.
w1 = post_gauss.stan_variable("W_1")
b1 = post_gauss.stan_variable("hidden_bias")

# Keep only the first draw of each.
W1, B1 = w1[0], b1[0]

In [None]:
import numpy as np

# Number of hidden units (rows of W1); features are divided by sqrt(m) so
# that different widths stay comparable.
m = W1.shape[0]

# tanh hidden features: affine pre-activation, nonlinearity, width scaling.
H_train = np.tanh(X_train @ W1.T + B1) / np.sqrt(m)   # shape (n_train, m)
H_test = np.tanh(X_test @ W1.T + B1) / np.sqrt(m)     # shape (n_test, m)


In [None]:
# Minimum-norm (ridgeless) readout: pseudo-inverse of the training features
# applied to the training targets; result has shape (m,).
H_train_pinv = np.linalg.pinv(H_train)
a_hat = H_train_pinv @ y_train
del H_train_pinv  # keep the notebook namespace unchanged


In [None]:
# Held-out predictions from the min-norm readout and their mean squared error.
y_pred_test = H_test @ a_hat
test_mse = np.mean(np.square(y_pred_test - y_test))

print("Test MSE for THIS ONE SAMPLE:", test_mse)


In [None]:
import numpy as np

# Element-wise nonlinearity for the hidden layer (pick the one your BNN uses).
def activation(U, kind="relu"):
    """Apply 'relu' or 'tanh' element-wise to U; any other `kind` raises ValueError."""
    if kind == "tanh":
        return np.tanh(U)
    if kind == "relu":
        return np.maximum(0.0, U)
    raise ValueError("Activation must be 'relu' or 'tanh'.")

# Build hidden features for a given (W1, b1) sample and a *chosen* width m_use (<= full width m_full)
def build_features(Z, W1, b1, m_use=None, act="relu", standardize=True,
                   stats=None, return_stats=False):
    """
    Hidden-layer design matrix for a readout fit.

    Parameters
    ----------
    Z : (n, d) latent inputs from §5.4
    W1 : (m_full, d) first-layer weights;  b1 : (m_full,) hidden biases
    m_use : number of hidden units to keep (sweeps γ = m_use / n); the first
        m_use units are used, all of them when None.
    act : 'relu' or 'tanh', forwarded to `activation`.
    standardize : if True, centre each column and divide by its standard
        deviation (columns with std < 1e-8 are only centred).
    stats : optional (mean, std) pair, each of length m_use. When given, these
        are used instead of the statistics of the current batch. Use this to
        transform held-out data with the statistics of the training batch.
    return_stats : if True, return ``(H, (mean, std))`` where the pair is the
        statistics actually applied (``None`` when standardize is False).

    Returns
    -------
    H : (n, m_use) features, divided by sqrt(m_use) after the optional
    standardisation; or ``(H, stats)`` when return_stats is True.

    With a single row (n = 1) and no `stats`, the ddof=1 standard deviation is
    undefined; the column spread is then treated as zero, so the row is centred
    to zeros instead of becoming NaN.
    """
    m_full = W1.shape[0]
    m = m_full if m_use is None else int(m_use)
    assert 1 <= m <= m_full, f"m_use must be in [1, {m_full}], got {m}"

    # Use first m units (you can also random-subselect or permute for robustness)
    U = Z @ W1[:m, :].T + b1[:m]  # pre-activation
    H = activation(U, act)

    # Scale features to be width-invariant and comparable across priors:
    # 1) column standardize (unit variance), 2) 1/sqrt(m) scaling
    applied_stats = None
    if standardize:
        if stats is None:
            mean = H.mean(axis=0)
            if H.shape[0] > 1:
                std = H.std(axis=0, ddof=1)
            else:
                std = np.zeros(H.shape[1])
        else:
            mean = np.asarray(stats[0], dtype=float)
            std = np.array(stats[1], dtype=float)  # copy: the caller's array is not modified
        std[std < 1e-8] = 1.0
        H = (H - mean) / std
        applied_stats = (mean, std)
    H = H / np.sqrt(m)

    if return_stats:
        return H, applied_stats
    return H


In [None]:
import numpy as np

def min_norm_fit(H, y):
    """Minimum-norm (ridgeless) least-squares readout: returns a_hat = H^+ y."""
    H_pinv = np.linalg.pinv(H)
    return H_pinv @ y

def test_mse_for_sample(W1, b1, Z_tr, y_tr, Z_te, y_te, m_use, act="relu"):
    """
    Given one prior sample (W1,b1), build H_tr/H_te with m_use units, fit min-norm readout,
    and return test MSE on fresh latent inputs Z_te (same §5.4 labeling rule).
    """
    H_tr = build_features(Z_tr, W1, b1, m_use=m_use, act=act, standardize=True)
    a_hat = min_norm_fit(H_tr, y_tr)

    H_te = build_features(Z_te, W1, b1, m_use=m_use, act=act, standardize=True)
    y_pred = H_te @ a_hat
    return float(np.mean((y_pred - y_te)**2))


In [None]:
import numpy as np
import matplotlib.pyplot as plt

def prior_predictive_risk_curve(
    W1_list, b1_list,                     # your S samples, all with the same full width m_full and input dim d
    n_train=400, n_test=5000, d=20,
    widths=None,                          # list of m_use values; by default a sweep around interpolation
    act="relu", sigma_xi=0.0, r_theta=1.0,
    reps_data=3,                          # average over a few train/test redraws for stability
    seed=0
):
    """
    Risk-versus-width profile for a set of prior samples, with a plot.

    For each width m_use (γ = m_use / n_train), average test MSE over:
      - reps_data redraws of (Z_tr, y_tr, Z_te, y_te) per §5.4
      - all S prior samples (W1,b1)
    and draw the mean curve with a 1.96·std/sqrt(S·reps_data) band.

    When `widths` is None, a default sweep around γ≈1 is used, clipped to the
    full width m_full of the first sample and de-duplicated.

    Returns (gammas, mean_risk, std_risk)

    Raises
    ------
    ValueError
        If W1_list is empty, or W1_list and b1_list differ in length (pairing
        them would otherwise silently drop samples while S is still used for
        the confidence band and the title).
    """
    S = len(W1_list)
    if S == 0:
        raise ValueError("W1_list must contain at least one prior sample.")
    if len(b1_list) != S:
        raise ValueError(
            f"W1_list and b1_list must have the same length (got {S} and {len(b1_list)})."
        )

    rng = np.random.default_rng(seed)
    m_full = W1_list[0].shape[0]

    if widths is None:
        # Sweep below/near/above interpolation threshold γ≈1
        widths = [max(1, int(w)) for w in (0.5*n_train, 0.8*n_train, 0.95*n_train, n_train,
                                           1.1*n_train, 1.5*n_train, 2*n_train, 3*n_train, 5*n_train)]
        widths = [min(int(m_full), w) for w in widths]
        widths = sorted(set(widths))

    gammas = [w / n_train for w in widths]
    mean_risks, std_risks = [], []

    for m_use in widths:
        risks = []
        for _ in range(reps_data):
            # Fresh §5.4 data each rep (same for all S samples)
            # NOTE(review): make_sec54_data is not defined in the visible part of
            # this notebook — confirm it is available before calling this function.
            # NOTE(review): the train and test sets come from two separate calls;
            # if each call draws its own θ, the two sets follow different labeling
            # rules — verify that θ is shared.
            Z_tr, y_tr, theta = make_sec54_data(n_train, d=d, r_theta=r_theta, sigma_xi=sigma_xi, rng=rng)
            Z_te, y_te, _     = make_sec54_data(n_test,  d=d, r_theta=r_theta, sigma_xi=sigma_xi, rng=rng)

            # Average across prior samples
            for W1, b1 in zip(W1_list, b1_list):
                risks.append(test_mse_for_sample(W1, b1, Z_tr, y_tr, Z_te, y_te, m_use=m_use, act=act))

        mean_risks.append(float(np.mean(risks)))
        std_risks.append(float(np.std(risks, ddof=1)))

    # Plot
    ci = 1.96 * (np.array(std_risks) / np.sqrt(S * reps_data))
    plt.figure(figsize=(6.6, 4.6))
    plt.plot(gammas, mean_risks, marker="o", linewidth=2)
    plt.fill_between(gammas, np.array(mean_risks) - ci, np.array(mean_risks) + ci, alpha=0.2)
    plt.axvline(1.0, linestyle="--", linewidth=1)  # interpolation threshold γ = 1
    plt.xlabel(r"Aspect ratio  $\gamma = m / n_{\rm train}$  (width / sample size)")
    plt.ylabel("Test MSE on §5.4 latent task")
    plt.title(f"Prior-predictive descent profile (S={S}, reps={reps_data}, act={act})")
    plt.tight_layout()
    plt.show()

    return np.array(gammas), np.array(mean_risks), np.array(std_risks)

# --- Example usage (one prior family) ---
# Suppose you already have S samples: W1_list = [W1_s], b1_list = [b1_s]
# Each W1_s: shape (m_full, d), each b1_s: shape (m_full,)
# gammas, mean_risk, std_risk = prior_predictive_risk_curve(W1_list, b1_list, d=W1_list[0].shape[1],
#                                                           n_train=400, n_test=5000, act="relu",
#                                                           sigma_xi=0.0, r_theta=1.0, reps_data=3, seed=123)
