In [None]:
# ==== Quick AMP smoke test over several random datasets (no MPI) ====
# Setup cell: imports NumPy and the helpers from
# valorisation_clean.ae_state_evolution that the test cells below exercise.
import numpy as np
from time import perf_counter

import sys
import os

# Make the package importable by appending the parent directory to sys.path.
# os.path.abspath("..") is resolved against the kernel's current working
# directory, so this assumes the kernel was started in the notebook's own
# folder -- TODO confirm. Re-running this cell appends the same entry again.
sys.path.append(os.path.abspath(".."))

# NOTE(review): perf_counter, F_2 and state_func are not referenced in the
# cells visible here; they may be used further down the notebook.
from valorisation_clean.ae_state_evolution import beta_tilde, F_1, F_2, F_2_batch, find_proximal_relu, hat_vars_MCMC, run_one_alpha, state_func, generate_latents

In [2]:
def test_F1_F2_batch():
    """Smoke-check F_1 and F_2_batch for a single parameter setting.

    Asserts that F_1(beta_tilde(beta_v)) has shape (2, 2), that its [0, 0]
    entry is 0 and its [1, 1] entry equals -beta_tilde(beta_v), and that
    F_2_batch returns an array of shape (n, 2), where n is the count passed
    to generate_latents. Prints a confirmation line when all checks pass.
    """
    beta = (0.5, 0.7)
    _, beta_v = beta  # only the second component is used directly here
    bt_v = beta_tilde(beta_v)

    f1_matrix = F_1(bt_v)
    assert f1_matrix.shape == (2, 2)
    assert f1_matrix[0, 0] == 0
    assert f1_matrix[1, 1] == -bt_v

    n_latents = 10
    lam, nu = generate_latents(n_latents, gamma=0.1, delta=0.2)
    f2_rows = F_2_batch(beta, lam, nu, np.eye(2), bt_v)
    assert f2_rows.shape == (n_latents, 2)
    print("F_1 and F_2_batch OK")

test_F1_F2_batch()


F_1 and F_2_batch OK


In [3]:
def brute_force_objective(h, omega, q, V):
    """Evaluate the scalar objective used to rank proximal candidates.

    Computes, element-wise with NumPy broadcasting,

        -h * max(0, h) + 0.5 * q * max(0, h)**2 + 0.5 * (h - omega)**2 / V

    Parameters
    ----------
    h : float or array-like
        Point(s) at which the objective is evaluated.
    omega : float or array-like
        Centre of the quadratic penalty; must broadcast against ``h``.
    q : float or array-like
        Weight of the squared rectified term.
    V : float or array-like
        Divisor of the quadratic penalty.

    Returns
    -------
    Objective value(s), with the broadcast shape of the inputs.
    """
    rectified = np.maximum(0, h)
    cross_term = -h * rectified
    square_term = 0.5 * rectified**2 * q
    penalty_term = 0.5 * (h - omega)**2 / V
    return cross_term + square_term + penalty_term

def test_find_proximal_relu():
    """Compare find_proximal_relu with the best of three closed-form candidates.

    For each omega on a 5-point grid over [-3, 3], the candidates h = omega,
    h = 0 and h = omega / (1 + V * (2 * q - 4)) are scored with
    brute_force_objective. The lowest-scoring candidate per row must match
    the value returned by find_proximal_relu (up to np.allclose tolerance).
    """
    np.random.seed(0)
    samples = 5
    q, V = 1.3, 0.7
    omega = np.linspace(-3, 3, samples)[:, None]  # column vector, shape (samples, 1)

    prox = find_proximal_relu(omega, q, V)
    assert prox.shape == omega.shape

    # One column per candidate, in the order: negative branch, zero, positive
    # branch. These are meant to mirror the three candidates used inside
    # find_proximal_relu.
    # NOTE(review): for h > 0 the stationary point of brute_force_objective is
    # omega / (1 + V * (q - 2)), not omega / (1 + V * (2 * q - 4)). Confirm
    # which scaling find_proximal_relu is supposed to use.
    candidates = np.hstack([
        omega,
        np.zeros_like(omega),
        omega / (1 + V * (2 * q - 4)),
    ])

    # omega broadcasts across the candidate columns, so each entry is scored
    # against the omega of its own row.
    scores = brute_force_objective(candidates, omega, q, V)
    winners = np.argmin(scores, axis=1)
    prox_bf = np.take_along_axis(candidates, winners[:, None], axis=1)

    assert np.allclose(prox, prox_bf)
    print("find_proximal_relu OK; prox =", prox.ravel())

test_find_proximal_relu()


find_proximal_relu OK; prox = [-3.  -1.5  0.   1.5  3. ]


In [4]:
def test_hat_vars_MCMC_shapes():
    """Check the shape, scalar-ness and finiteness of hat_vars_MCMC outputs.

    Calls hat_vars_MCMC once with a fixed seed and a hand-picked state
    (m, q, V), prints the three returned values, and asserts that m_hat has
    shape (1, 2) and that q_hat and V_hat are finite scalars.

    NOTE(review): the output recorded with this cell shows m_hat == [[0, 0]]
    and V_hat == -q_hat (negative). Both pass the checks below, which cover
    only shape and finiteness -- confirm those values are expected.
    """
    np.random.seed(0)
    alpha, beta = 1.2, (0.5, 0.7)
    gamma, delta = 0.1, 0.2
    samples = 2000

    # State triple (m, q, V); m is a 1x2 row vector.
    q_vars = (np.array([[0.3, -0.4]]), 1.0, 1.0)

    m_hat, q_hat, V_hat = hat_vars_MCMC(alpha, beta, q_vars, samples, gamma, delta)

    print("m_hat:", m_hat, "shape:", m_hat.shape)
    print("q_hat:", q_hat, "type:", type(q_hat))
    print("V_hat:", V_hat, "type:", type(V_hat))

    assert m_hat.shape == (1, 2)
    assert np.isfinite(m_hat).all()

    scalar_hats = (q_hat, V_hat)
    for value in scalar_hats:
        assert np.isscalar(value)
    for value in scalar_hats:
        assert np.isfinite(value)

    print("hat_vars_MCMC shapes & finiteness OK")

test_hat_vars_MCMC_shapes()


m_hat: [[0. 0.]] shape: (1, 2)
q_hat: 0.562644499015001 type: <class 'numpy.float64'>
V_hat: -0.562644499015001 type: <class 'numpy.float64'>
hat_vars_MCMC shapes & finiteness OK


In [5]:
def test_run_one_alpha_single_rank():
    """Run a short run_one_alpha trajectory and sanity-check its last entries.

    Runs 5 iterations with a fixed seed, then asserts that both returned
    trajectories have one entry per iteration and that the final (m, q, V)
    and (m_hat, q_hat, V_hat) triples consist of a finite (1, 2) array and
    two finite scalars.

    NOTE(review): in the output logged with this cell, V turns negative from
    iteration 1 and q reaches ~376 at iteration 3. This test only checks
    shape and finiteness -- confirm that trajectory is expected.
    """
    np.random.seed(0)
    iters = 5
    q_init = (np.array([[0.1, 0.0]]), 1.0, 1.0)  # initial (m, q, V)

    state_traj, hat_traj = run_one_alpha(
        alpha=1.2,
        beta=(0.5, 0.7),
        gamma=0.1,
        delta=0.2,
        q_init=q_init,
        samples=1000,
        iters=iters,
        lam=1e-6,
        damping=0.7,
        print_every=1,
    )

    # Only len() and indexing are used on the trajectories, so the checks do
    # not depend on their container type (presumably an object array built
    # from a list of tuples -- verify against run_one_alpha).
    assert len(state_traj) == iters
    assert len(hat_traj) == iters

    m_last, q_last, V_last = state_traj[-1]
    m_hat_last, q_hat_last, V_hat_last = hat_traj[-1]

    print("Last state m,q,V:", m_last, q_last, V_last)
    print("Last hats m_hat,q_hat,V_hat:", m_hat_last, q_hat_last, V_hat_last)

    # Same six checks for the state triple, then for the hat triple.
    final_triples = (
        (m_last, q_last, V_last),
        (m_hat_last, q_hat_last, V_hat_last),
    )
    for m_value, q_value, V_value in final_triples:
        assert m_value.shape == (1, 2)
        assert np.isfinite(m_value).all()
        assert np.isscalar(q_value)
        assert np.isscalar(V_value)
        assert np.isfinite(q_value)
        assert np.isfinite(V_value)

    print("run_one_alpha single-rank OK")

test_run_one_alpha_single_rank()


[alpha 1.200 | iter 0] m, q, V =
(array([[0.1, 0. ]]), 1.0, 1.0)
[alpha 1.200 | iter 0] m_hat, q_hat, V_hat (local) =
(array([[0., 0.]]), 0.5524037451595363, -0.5524037451595363)
[alpha 1.200 | iter 1] m, q, V =
(array([[0.03, 0.  ]]), 1.5671936794760537, -0.9671913855131857)
[alpha 1.200 | iter 1] m_hat, q_hat, V_hat (local) =
(array([[ 0.01480542, -0.053959  ]]), 2.0482254590627798, -0.1258407520418819)
[alpha 1.200 | iter 2] m, q, V =
(array([[-0.07335707,  0.30015397]]), 91.14846815144614, -5.852787575375078)
[alpha 1.200 | iter 2] m_hat, q_hat, V_hat (local) =
(array([[-0.04493912, -0.00654442]]), 3.2609203095566413, -0.08091070319308964)
[alpha 1.200 | iter 3] m, q, V =
(array([[0.36678905, 0.14666599]]), 376.25307303448955, -10.407456194261002)
[alpha 1.200 | iter 3] m_hat, q_hat, V_hat (local) =
(array([[-4.52488105e-05,  9.26952707e-02]]), 4.204983374541392, 0.5670764903872968)
[alpha 1.200 | iter 4] m, q, V =
(array([[0.10998086, 0.15842277]]), 122.04792162595675, -1.88783765