In [2]:
# ==== Quick AMP smoke test over several random datasets (no MPI) ====
import numpy as np
from time import perf_counter

import sys
import os

# Add parent folder (one level up from notebook) to Python path
sys.path.append(os.path.abspath(".."))

from amp_experiment import generate_data
from amp import AMP_algo
from state_evolution import generate_latents

import matplotlib.pyplot as plt
from tqdm import tqdm



In [3]:
def run_many_amp(
    runs=3, alpha=10, d=300,
    beta_u=1.0, beta_v=2.0, gamma=1.0, delta=0.0,
    plant=0.9, damp=0.2, tol=1e-4, max_iter=200, eps=1e-6, seed0=0,
    verbose=True
):
    """Run AMP on several freshly generated datasets and collect the outcomes.

    For each run a dataset of ``n = int(alpha * d)`` samples is drawn with
    ``generate_data``, only the samples labelled ``y == 1`` are kept, and
    ``AMP_algo`` is run on them with ``K=2`` and ``rank=None`` (no MPI).

    Parameters
    ----------
    runs : int
        Number of independent datasets / AMP runs.
    alpha : float
        Sample ratio; the number of generated samples is ``int(alpha * d)``.
    d : int
        Passed to ``generate_data`` as the second positional argument
        (presumably the feature dimension -- confirm against ``generate_data``).
    beta_u, beta_v, gamma, delta : float
        Model parameters forwarded unchanged to both ``generate_data`` and
        ``AMP_algo``.
    plant, damp, tol, max_iter, eps :
        Solver settings forwarded unchanged to ``AMP_algo``.
    seed0 : int
        Run ``r`` seeds NumPy's global RNG with ``seed0 + r`` before the data
        is generated.
    verbose : bool
        When true (default), print the final overlap and a one-line status
        per run, exactly as before this flag existed.

    Returns
    -------
    list of dict
        One entry per run with keys ``"run"``, ``"converged"``, ``"steps"``,
        ``"time_s"``, ``"w"``, ``"overlap_hist"`` and ``"Q_final"``.
        ``"Q_final"`` is the last element of ``overlap_hist``, or ``None`` when
        the history is empty (in which case ``"steps"`` is 0).
    """
    results = []
    n = int(alpha * d)
    for r in range(runs):
        # Seed the global NumPy RNG so each run is reproducible on its own.
        # NOTE(review): this assumes generate_data / AMP_algo draw from
        # np.random's global state -- confirm.
        np.random.seed(seed0 + r)

        # --- data ---
        X, y, u_star, v_star = generate_data(n, d, beta_u, beta_v, gamma=gamma, delta=delta)
        w_star = np.stack([u_star, v_star], axis=1)  # (d, 2)

        # Keep only the positively labelled samples.
        # TODO: y_pos is constant (all ones) after this filter, so passing it
        # on may be redundant -- simplify once AMP_algo allows it.
        positive = (y == 1)
        X_pos = X[positive]
        y_pos = y[positive]

        # --- run AMP ---
        t0 = perf_counter()
        w, overlap_hist, converged = AMP_algo(
            X_pos, y_pos, K=2,
            beta_u=beta_u, beta_v=beta_v, gamma=gamma, delta=delta,
            max_iter=max_iter, tol=tol, plant=plant, damp=damp,
            w_star=w_star, eps=eps, rank=None  # no MPI
        )
        elapsed = perf_counter() - t0

        # An empty history would make overlap_hist[-1] raise IndexError, so
        # report "no final overlap" instead of crashing the whole batch.
        has_history = len(overlap_hist) > 0
        Q_final = overlap_hist[-1] if has_history else None  # final overlap matrix (KxK)
        steps = len(overlap_hist) - 1 if has_history else 0

        if verbose:
            print("Q")
            print(Q_final)
            print(f"[run {r}] converged={converged} | steps={steps} | time={elapsed:.2f}s")

        results.append({
            "run": r,
            "converged": converged,
            "steps": steps,
            "time_s": elapsed,
            "w": w,
            "overlap_hist": overlap_hist,  # shape: T x K x K
            "Q_final": Q_final,
        })
    return results

# ---- smoke run: adjust problem size and solver settings as needed ----
res = run_many_amp(
    # problem size: 5 datasets, each with int(alpha * d) generated samples
    runs=5,
    alpha=30,
    d=500,
    # model parameters (forwarded to data generation and to AMP)
    beta_u=1.0,
    beta_v=2.0,
    gamma=1.0,
    delta=0.0,
    # solver settings
    plant=0.9,
    damp=0.7,
    tol=1e-4,
    max_iter=1000,
    eps=1e-6,
    # run r is seeded with seed0 + r
    seed0=120,
)

Iteration 0


  return 1/np.sqrt(a)*(np.exp(-c)*np.exp((d-b)**2/(2*a))*erfc(-(d-b)/np.sqrt(2*a))/2 + np.exp(c)*np.exp((b+d)**2/(2*a))*erfc(-(b+d)/np.sqrt(2*a))/2)
  return 1/np.sqrt(a)*(np.exp(-c)*np.exp((d-b)**2/(2*a))*erfc(-(d-b)/np.sqrt(2*a))/2 + np.exp(c)*np.exp((b+d)**2/(2*a))*erfc(-(b+d)/np.sqrt(2*a))/2)
  return 1/np.sqrt(a)*(np.exp(-c)*np.exp((d-b)**2/(2*a))*erfc(-(d-b)/np.sqrt(2*a))/2 + np.exp(c)*np.exp((b+d)**2/(2*a))*erfc(-(b+d)/np.sqrt(2*a))/2)
  return 1/np.sqrt(a)*(np.exp(-c)*np.exp((d-b)**2/(2*a))*erfc(-(d-b)/np.sqrt(2*a))/2 + np.exp(c)*np.exp((b+d)**2/(2*a))*erfc(-(b+d)/np.sqrt(2*a))/2)
  return 1/(np.sqrt(a)*a)*(np.exp(c)*np.exp(((b+d)**2)/(2*a))*(d+b)*erfc(-(d+b)/np.sqrt(2*a))/2-np.exp(-c)*np.exp((d-b)**2/(2*a))*erfc(-(d-b)/np.sqrt(2*a))/2*(d-b)) + 1/(np.sqrt(2*np.pi)*a)*(np.exp(c)-np.exp(-c))
  return 1/(np.sqrt(a)*a)*(np.exp(c)*np.exp(((b+d)**2)/(2*a))*(d+b)*erfc(-(d+b)/np.sqrt(2*a))/2-np.exp(-c)*np.exp((d-b)**2/(2*a))*erfc(-(d-b)/np.sqrt(2*a))/2*(d-b)) + 1/(np.sqrt(2*np.pi)*a)*(

NaN encountered at iter 4
Q
[[nan nan]
 [nan nan]]
[run 0] converged=False | steps=5 | time=0.74s
Iteration 0
Iteration 20
Iteration 40
Iteration 60
Iteration 80
Converged at iter 95
Q
[[-0.79622218 -0.0755183 ]
 [-0.07691849 -0.86256078]]
[run 1] converged=True | steps=96 | time=14.26s
Iteration 0
Iteration 20
Iteration 40
Iteration 60
Converged at iter 76
Q
[[0.97975187 0.04434556]
 [0.021866   0.90237103]]
[run 2] converged=True | steps=77 | time=10.92s
Iteration 0
NaN encountered at iter 16
Q
[[nan nan]
 [nan nan]]
[run 3] converged=False | steps=17 | time=2.30s
Iteration 0
NaN encountered at iter 4
Q
[[nan nan]
 [nan nan]]
[run 4] converged=False | steps=5 | time=0.65s


In [None]:
# --- choose sweep settings ---
d = 500
alpha_list = np.linspace(0.1, 10, 30)
runs_per_alpha = 10
plant = 0.7

# --- storage (one entry per alpha, filled in sweep order) ---
alpha_vals = []
q00_mean = []
q00_std = []

# Wrap the array itself rather than the enumerate object: enumerate() has no
# length, so tqdm(enumerate(...)) cannot show a total or an ETA.
for i, alpha in enumerate(tqdm(alpha_list, desc="alpha sweep")):
    res = run_many_amp(
        runs=runs_per_alpha,
        alpha=alpha,
        d=d,
        beta_u=1.0,
        beta_v=0.0,
        gamma=0.0,
        delta=0.0,
        plant=plant,
        damp=0.5,
        tol=1e-4,
        max_iter=500,
        eps=1e-6,
        seed0=123 + 1000*i  # change seeds across alpha values
    )

    # Final Q[0, 0] of every converged run that recorded at least one overlap.
    # NOTE(review): AMP runs in this notebook have converged to overlaps of
    # either sign; if that is a symmetry of the model, averaging the signed
    # value can cancel out and |Q00| should be used instead -- confirm.
    q00s = [
        r["overlap_hist"][-1][0, 0]  # last overlap matrix has shape (K, K)
        for r in res
        if r["converged"] and len(r["overlap_hist"]) > 0
    ]
    # A single non-finite overlap would turn the whole mean into NaN.
    q00s = [q for q in q00s if np.isfinite(q)]

    alpha_vals.append(alpha)
    if q00s:
        q00_mean.append(float(np.mean(q00s)))
        q00_std.append(float(np.std(q00s)))
    else:
        # No usable run at this alpha: keep the point so the lists stay
        # aligned, but mark it as missing.
        q00_mean.append(np.nan)
        q00_std.append(np.nan)

# --- plot: mean final Q00 per alpha, error bars = std over the runs kept ---
fig, ax = plt.subplots(figsize=(6, 4))
ax.errorbar(alpha_vals, q00_mean, yerr=q00_std, fmt='o-', capsize=3)
ax.set_xlabel(r'$\alpha = n/d$')
ax.set_ylabel(r'Final overlap $Q_{00}$')
ax.set_title(r'AMP: $Q_{00}$ vs $\alpha$  ($\beta_u=1,\ \beta_v=0$)')
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()
