# ML Project 
## continuous neurones

----------------------

 ### **First task:**
- #### *find s+ such that the convergence time T(N) became fast*

In [24]:
import math
from dataclasses import dataclass
from typing import Tuple, List, Dict, Optional

import jax
import jax.numpy as jnp
import equinox as eqx
from jax import random, lax
import matplotlib.pyplot as plt

##### **Model definition**

In [25]:
class ContRNN(eqx.Module):
    J: jnp.ndarray  # (N, N)
    g: float
    s: float
    dt: float = 0.5

    def step(self, x: jnp.ndarray) -> jnp.ndarray:
        y = jnp.tanh(x)
        dx = -x + self.s * y + self.g * (self.J @ y)
        return x + self.dt * dx
# this function define the iteractive system

def init_J(key: jax.Array, N: int) -> jnp.ndarray:
    J = random.normal(key, shape=(N, N)) / jnp.sqrt(N)
    J = J.at[jnp.arange(N), jnp.arange(N)].set(0.0)
    return J
#this function randomly initialize J with Jii=0

def make_seed(g: float, s: float, j: int, N: int) -> int:
    a = int(round(1000 * g))
    b = int(round(1000 * s))
    seed = (a * 0x9E3779B1 + b * 0x85EBCA77 + j * 0xC2B2AE3D + N * 0x27D4EB2F) & 0xFFFFFFFF
    return jnp.uint32(seed)
#this function make a key per each combination

##### **Single run simulation**

In [26]:
@dataclass
class RunConfig:
    dt: float = 0.5
    eps: float = 1e-5
    tmax_steps: int = 10000

def median_time_over_J_and_restarts(
    master_key: jax.Array, N: int, g: float, s: float, cfg: RunConfig, M: int = 3, R: int = 7
) -> Tuple[float, float]:
    """return the median of the time on M different matrix J and R restart for each one."""
    Ts_all = []
    for m in range(M):
        keyJ = random.fold_in(master_key, m)  # new matrix J
        keys = random.split(random.fold_in(keyJ, 12345), R)  # R static initially
        run_v = jax.vmap(lambda k: one_run(k, N, g, s, cfg))
        Ts, flags = run_v(keys)
        Ts_all.append(jnp.where(flags, Ts, jnp.nan))  # take just the convergent ones

    Ts_all = jnp.concatenate(Ts_all, axis=0)
    median_T = jnp.nanmedian(Ts_all)
    conv_rate = jnp.isfinite(Ts_all).mean()
    return float(median_T), float(conv_rate)

def one_run(key: jax.Array, N: int, g: float, s: float, cfg: RunConfig) -> Tuple[jnp.ndarray, jnp.ndarray]:
    kJ, kx = random.split(key, 2)
    J = init_J(kJ, N)
    model = ContRNN(J=J, g=g, s=s, dt=cfg.dt)
    x = random.normal(kx, shape=(N,))
#definition of the model and variables

    def body_fun(carry):
        x, prev_x, t = carry
        x_next = model.step(x)
        return (x_next, x, t + 1)
#the update 

    def cond_fun(carry):
        x, prev_x, t = carry
        dx = jnp.linalg.norm(x - prev_x)
        tol = cfg.eps * (1.0 + jnp.linalg.norm(x))
        not_conv = dx > tol
        not_timeout = t < cfg.tmax_steps
        return jnp.logical_and(not_conv, not_timeout)
#stops when time is not converging

    init_prev = x + 10.0
    x, prev_x, steps = lax.while_loop(cond_fun, body_fun, (x, init_prev, 0))

    dx_final = jnp.linalg.norm(x - prev_x)
    tol_final = cfg.eps * (1.0 + jnp.linalg.norm(x))
    converged = dx_final <= tol_final

    T = cfg.dt * jnp.asarray(steps, dtype=jnp.float32)
    T = jnp.where(converged, T, cfg.dt * cfg.tmax_steps)
    return T, converged


##### **Vectorized runs**

In [27]:
def median_time_over_restarts(
    key: jax.Array, N: int, g: float, s: float, cfg: RunConfig, R: int = 5
) -> Tuple[float, float, float, float]:
    keys = random.split(key, R)
    run_v = jax.vmap(lambda k: one_run(k, N, g, s, cfg))
    Ts, flags = run_v(keys)

    conv_mask = flags.astype(bool)
    conv_Ts = jnp.where(conv_mask, Ts, jnp.nan)

    def robust_quantile(arr, q):
        arr_sorted = jnp.sort(jnp.where(jnp.isnan(arr), jnp.inf, arr))
        valid = jnp.isfinite(arr_sorted)
        num = jnp.sum(valid)

        def _fallback():
            return cfg.dt * cfg.tmax_steps

        def _quant():
            idx = jnp.maximum(0, jnp.minimum(num - 1, jnp.floor((num - 1) * q).astype(int)))
            return arr_sorted[idx]

        return lax.cond(num == 0, _fallback, _quant)

    median_T = robust_quantile(conv_Ts, 0.5)
    q25 = robust_quantile(conv_Ts, 0.25)
    q75 = robust_quantile(conv_Ts, 0.75)
    conv_rate = jnp.mean(conv_mask.astype(jnp.float32))
    return float(median_T), float(q25), float(q75), float(conv_rate)

##### **Fit utilities**<br> i.e. find when logT(N) fit best N^r or a*log(N)

In [28]:
@dataclass
class FitResult:
    kind: str
    bic_poly: float
    bic_exp: float
    a_exp: float
    k_poly: float
    n_used: int


def _linfit(x: jnp.ndarray, y: jnp.ndarray) -> Tuple[float, float, float]:
    x_mean = jnp.mean(x)
    y_mean = jnp.mean(y)
    x_var = jnp.mean((x - x_mean) ** 2)
    a = jnp.where(x_var > 0, jnp.mean((x - x_mean) * (y - y_mean)) / x_var, 0.0)
    b = y_mean - a * x_mean
    yhat = a * x + b
    sse = jnp.sum((y - yhat) ** 2)
    return float(a), float(b), float(sse)


def _bic(n: int, sse: float, k_params: int = 2) -> float:
    if n <= k_params or sse <= 0:
        return jnp.inf
    return float(n * math.log(sse / n) + k_params * math.log(n))


def pick_model(Ns, Ts, tmax, *, bic_margin=5.0, censor_thresh=0.3) -> FitResult:
    mask = Ts < tmax - 1e-8
    n_all = int(Ns.shape[0])
    Ns_used, Ts_used = Ns[mask], Ts[mask]
    n = int(Ns_used.shape[0])

    # molta censura → nonconv
    if 1.0 - (n / max(n_all,1)) >= censor_thresh:
        return FitResult("nonconv", math.inf, math.inf, 0.0, 0.0, n)

    if n < 3:
        # con pochi punti, se n>=2 forza la scelta col BIC (niente margine)
        if n >= 2:
            logT, logN = jnp.log(Ts_used), jnp.log(Ns_used)
            a_exp, _, sse_exp = _linfit(Ns_used, logT); bic_exp = _bic(n, sse_exp, 2)
            k_poly, _, sse_poly = _linfit(logN, logT); bic_poly = _bic(n, sse_poly, 2)
            kind = "exp" if bic_exp < bic_poly else "poly"
            return FitResult(kind, float(bic_poly), float(bic_exp), float(a_exp), float(k_poly), n)
        return FitResult("unknown", math.inf, math.inf, 0.0, 0.0, n)

    logT, logN = jnp.log(Ts_used), jnp.log(Ns_used)
    a_exp, _, sse_exp = _linfit(Ns_used, logT); bic_exp = _bic(n, sse_exp, 2)
    k_poly, _, sse_poly = _linfit(logN, logT); bic_poly = _bic(n, sse_poly, 2)

    if bic_poly + bic_margin < bic_exp: kind = "poly"
    elif bic_exp + bic_margin < bic_poly: kind = "exp"
    else:
        # tie-break: comunque scegli il migliore
        kind = "exp" if bic_exp < bic_poly else "poly"

    return FitResult(kind, float(bic_poly), float(bic_exp), float(a_exp), float(k_poly), n)


##### **cache + binary search**

In [29]:
from typing import Any, Tuple, Dict

# cache: s -> (kind, FitResult|None)
Cache = Dict[float, Tuple[str, Optional[FitResult]]]

def compute_fit_for_s_with_cache(
    master_key, g: float, s: float, N_list, cfg,
    *, M=3, R=7, conv_thresh=0.6, bic_margin=5.0, cache: Cache
) -> Tuple[str, Optional[FitResult]]:
    """Ritorna (kind, fit) usando/aggiornando la cache. Valutazione parametrica (M,R,bic_margin)."""
    # cache key include anche M,R,bic_margin per sicurezza?
    # Qui teniamo la cache solo per s (come prima) per semplicità;
    # sotto, quando servono valutazioni robuste, sovrascriviamo semplicemente la voce di cache.
    if s in cache:
        return cache[s]

    Ts, Ns, convs = [], [], []
    for N in N_list:
        med, conv_rate = median_time_over_J_and_restarts(master_key, N, g, s, cfg, M=M, R=R)
        Ts.append(med); Ns.append(N); convs.append(conv_rate)

    Ns = jnp.asarray(Ns); Ts = jnp.asarray(Ts); convs = jnp.asarray(convs)
    KEEP = convs >= conv_thresh
    Ns_k, Ts_k = Ns[KEEP], Ts[KEEP]

    if (KEEP.mean() < 0.3) or ((~KEEP[-2:]).all()) or (Ns_k.size < 3):
        cache[s] = ("nonconv", None)
        return cache[s]

    fit = pick_model(Ns_k, Ts_k, cfg.dt * cfg.tmax_steps, bic_margin=bic_margin)
    cache[s] = (fit.kind, fit)
    return cache[s]



def bracket_from_cache(cache: Cache, s_grid) -> Tuple[Optional[float], Optional[float]]:
    """Trova ultimo s con exp e primo s con poly già calcolati nel coarse scan."""
    s_exp_last = None
    s_poly_first = None
    for s in sorted(s_grid):
        kind, _ = cache.get(s, ("unknown", None))
        if kind == "exp":
            s_exp_last = s
        if (s_poly_first is None) and (kind == "poly"):
            s_poly_first = s
    return s_exp_last, s_poly_first

def first_poly_with_previous(cache: Cache, s_grid):
    """
    Trova il primo s della griglia con kind='poly' e ritorna (s_prev, s_poly).
    Se il primo poly è all'indice 0 o non esiste alcun poly, ritorna (None, None).
    """
    s_sorted = sorted(s_grid)
    for i, s in enumerate(s_sorted):
        kind, _ = cache.get(s, ("unknown", None))
        if kind == "poly":
            if i == 0:
                return None, None
            return s_sorted[i-1], s_sorted[i]
    return None, None

def binary_search_s_plus(
    master_key, g: float, N_list, cfg,
    lo: float, hi: float,
    *, cache: Cache, M=5, R=11, conv_thresh=0.6, bic_margin=7.5, s_tol=1e-2, max_steps=12
) -> float:
    """Ricerca binaria del confine dove compaiono i poly. Usa settaggi più robusti."""
    for _ in range(max_steps):
        if hi - lo < s_tol:
            break
        mid = 0.5 * (lo + hi)

        # ricalcola mid in modo robusto (M,R più alti, bic_margin più severo)
        cache.pop(mid, None)
        kind_mid, _ = compute_fit_for_s_with_cache(
            master_key, g, mid, N_list, cfg,
            M=M, R=R, conv_thresh=conv_thresh, bic_margin=bic_margin, cache=cache
        )

        if kind_mid == "poly":
            hi = mid
        elif kind_mid in ("exp", "nonconv"):
            lo = mid
        else:  # 'unknown'
            shrink = 0.2 * (hi - lo)
            lo = lo + 0.1 * shrink
            hi = hi - 0.1 * shrink

    return 0.5 * (lo + hi)



##### **helpers**

In [30]:
def first_poly_after_nonpoly(cache: Cache, s_grid):
    """
    Trova il PRIMO s della griglia per cui kind=='poly' e il precedente valido
    ha kind in {'exp','nonconv'}.
    Ritorna (s_prev, s_poly) oppure (None, None) se non esiste.
    """
    s_sorted = sorted(s_grid)

    def kind_of(s):
        k, _ = cache.get(s, ("unknown", None))
        return k  # può essere 'exp','poly','nonconv','unknown'

    prev_s = None
    prev_k = None
    for s in s_sorted:
        k = kind_of(s)
        if k == "poly" and prev_k in ("exp", "nonconv"):
            return prev_s, s
        # aggiorna solo se il punto precedente è “informativo”
        if k in ("exp", "poly", "nonconv"):
            prev_s, prev_k = s, k

    return None, None

def reinforce_kind_at_points(master_key, g, points, N_list, cfg, cache: Cache,
                             M_strong=5, R_strong=11, bic_margin_strong=7.5, conv_thresh=0.6):
    """Ricalcola (e sovrascrive in cache) il kind in modo più robusto per un set di s."""
    for s in points:
        # sovrascrive con una valutazione più 'forte'
        cache.pop(s, None)
        compute_fit_for_s_with_cache(
            master_key, g, s, N_list, cfg,
            M=M_strong, R=R_strong, conv_thresh=conv_thresh,
            bic_margin=bic_margin_strong, cache=cache
        )



##### **Main experiment with plotting**

In [31]:
def main():
    g_list = [0.25 * k for k in range(0, 9)]            # 0.00 .. 2.00
    s_grid = [0.25 * k for k in range(0, 19)]           # 0.00 .. 4.50
    N_list = [20, 30, 45, 68, 100, 150, 225, 338, 500]
    cfg = RunConfig(dt=0.5, eps=1e-5, tmax_steps=10000)
    master_key = random.PRNGKey(0)

    # parametri: coarse (veloce) e robust (solo su confini)
    COARSE = dict(M=3, R=7, conv_thresh=0.6, bic_margin=5.0)
    ROBUST = dict(M=5, R=11, conv_thresh=0.6, bic_margin=7.5)

    results = []

    for g in g_list:
        cache: Dict[float, Tuple[str, Optional[FitResult]]] = {}

        # === SCANSIONE GROSSOLANA SU s
        for s in s_grid:
            kind, _ = compute_fit_for_s_with_cache(
                master_key, g, s, N_list, cfg, cache=cache, **COARSE
            )
            print(f"[g={g:.2f}] s={s:.2f} -> {kind}")

        # === BRACKET: PRIMO poly con PRECEDENTE ∈ {exp, nonconv}
        p_lo, p_hi = first_poly_after_nonpoly(cache, s_grid)
        if (p_lo is not None) and (p_hi is not None) and (p_lo < p_hi):
            # rinforza i capi con valutazione più robusta
            reinforce_kind_at_points(master_key, g, [p_lo, p_hi], N_list, cfg, cache,
                                     M_strong=ROBUST["M"], R_strong=ROBUST["R"],
                                     bic_margin_strong=ROBUST["bic_margin"],
                                     conv_thresh=ROBUST["conv_thresh"])

            # ricontrolla che resti 'poly' su p_hi e non 'poly' su p_lo dopo il rinforzo
            kind_lo, _ = cache[p_lo]
            kind_hi, _ = cache[p_hi]
            if (kind_hi != "poly") or (kind_lo not in ("exp", "nonconv")):
                # se il rinforzo ha cambiato classificazione, riprova a cercare un nuovo bracket
                p_lo, p_hi = first_poly_after_nonpoly(cache, s_grid)

            if (p_lo is not None) and (p_hi is not None) and (p_lo < p_hi):
                s_plus = binary_search_s_plus(
                    master_key, g, N_list, cfg,
                    lo=p_lo, hi=p_hi,
                    cache=cache, s_tol=0.01, max_steps=12, **ROBUST
                )
                print(f"[g={g:.2f}] s+ ≈ {s_plus:.4f}  (first-poly bracket: [{p_lo:.3f}, {p_hi:.3f}])")
                results.append((g, float(s_plus), p_lo, p_hi, "ok_firstpoly"))
                continue

        # nessun bracket trovato con la regola desiderata
        print(f"[g={g:.2f}] Nessun bracket trovato.")
        results.append((g, None, None, None, "no_bracket"))

    # === RIEPILOGO
    print("\n=== s+(g) summary ===")
    for g, s_plus, lo, hi, status in results:
        if s_plus is not None:
            print(f"g={g:.2f}  s+≈{s_plus:.4f}   bracket=[{lo:.3f}, {hi:.3f}]  ({status})")
        else:
            print(f"g={g:.2f}  s+=N/A  ({status})")


if __name__ == "__main__":
    main()


[g=0.00] s=0.00 -> poly
[g=0.00] s=0.25 -> poly
[g=0.00] s=0.50 -> poly
[g=0.00] s=0.75 -> poly
[g=0.00] s=1.00 -> nonconv
[g=0.00] s=1.25 -> poly
[g=0.00] s=1.50 -> poly
[g=0.00] s=1.75 -> poly
[g=0.00] s=2.00 -> poly
[g=0.00] s=2.25 -> poly
[g=0.00] s=2.50 -> poly
[g=0.00] s=2.75 -> poly
[g=0.00] s=3.00 -> poly
[g=0.00] s=3.25 -> poly
[g=0.00] s=3.50 -> poly
[g=0.00] s=3.75 -> poly
[g=0.00] s=4.00 -> poly
[g=0.00] s=4.25 -> poly
[g=0.00] s=4.50 -> poly
[g=0.00] s+ ≈ 1.0039  (first-poly bracket: [1.000, 1.250])
[g=0.25] s=0.00 -> poly
[g=0.25] s=0.25 -> poly
[g=0.25] s=0.50 -> poly
[g=0.25] s=0.75 -> poly
[g=0.25] s=1.00 -> nonconv
[g=0.25] s=1.25 -> nonconv
[g=0.25] s=1.50 -> poly
[g=0.25] s=1.75 -> poly
[g=0.25] s=2.00 -> poly
[g=0.25] s=2.25 -> poly
[g=0.25] s=2.50 -> poly
[g=0.25] s=2.75 -> poly
[g=0.25] s=3.00 -> poly
[g=0.25] s=3.25 -> poly
[g=0.25] s=3.50 -> poly
[g=0.25] s=3.75 -> poly
[g=0.25] s=4.00 -> poly
[g=0.25] s=4.25 -> poly
[g=0.25] s=4.50 -> poly
[g=0.25] s+ ≈ 1.2695