In [18]:
import numpy as np
import scipy.sparse as sp
from time import perf_counter

def softplus_beta(x, beta=2.0):
    """Evaluate the beta-scaled softplus ``log(1 + exp(beta * x)) / beta``.

    Args:
        x: Scalar or array of inputs.
        beta: Sharpness factor multiplying ``x`` inside the exponential.

    Returns:
        Array of the same shape as ``x``. Where ``beta * x`` exceeds 20 the
        input ``x`` itself is returned; elsewhere the exponent is clipped to
        ``[-50, 50]`` before ``exp`` is applied.
    """
    scaled = beta * x
    # Bound the exponent so np.exp can neither overflow nor underflow.
    bounded = np.clip(scaled, -50, 50)
    smooth = np.log1p(np.exp(bounded)) / beta
    # For large arguments softplus is treated as the identity.
    return np.where(scaled > 20.0, x, smooth)

def softplus_inv_beta(y, beta=2.0, eps=1e-12):
    """Invert the beta-scaled softplus ``y = log(1 + exp(beta * x)) / beta``.

    Args:
        y: Scalar or array of softplus outputs; negative entries are treated
            as 0 before inversion.
        beta: Sharpness factor used by the forward map.
        eps: Small offset added inside the logarithms so that ``y == 0``
            yields a finite (large negative) pre-image instead of ``-inf``.

    Returns:
        Array ``x`` with the same shape as ``y``. Entries with
        ``y > 10 / beta`` use ``y + log1p(-exp(-beta * y) + eps) / beta``;
        the remaining entries use ``log(expm1(beta * y) + eps) / beta``.
    """
    y = np.maximum(y, 0.0)
    t = beta * y
    large = (y > 10.0 / beta)
    x_large = y + (1.0/beta)*np.log1p(-np.exp(-beta*y) + eps)
    # np.where evaluates both branches for every element, so the small-y
    # formula is also computed for large y. Cap its exponent well above the
    # switch-over point (t ~ 10) but far below the expm1 overflow threshold
    # so discarded entries no longer raise overflow warnings/errors; entries
    # that are actually selected are left untouched by the cap.
    t_small = np.minimum(t, 50.0)
    x_small = (1.0/beta)*np.log(np.expm1(t_small) + eps)
    return np.where(large, x_large, x_small)

def sigmoid_beta(x, beta=2.0):
    """Evaluate the logistic function of ``beta * x``.

    Args:
        x: Scalar or array of inputs.
        beta: Factor multiplying ``x`` before the logistic map.

    Returns:
        ``1 / (1 + exp(-beta * x))`` with the argument clipped to
        ``[-60, 60]`` so the exponential stays finite.
    """
    arg = np.clip(beta * x, -60, 60)
    denom = 1.0 + np.exp(-arg)
    return 1.0 / denom

def tri_laplacian(P, w=0.4):
    """Build a scaled tridiagonal second-difference matrix.

    Args:
        P: Number of rows/columns of the square matrix.
        w: Weight; the ``[1, -2, 1]`` stencil is multiplied by ``w / 4``.

    Returns:
        A ``P x P`` sparse CSR matrix with ``-2 * w / 4`` on the main
        diagonal and ``w / 4`` on the first sub- and super-diagonals.
    """
    neighbours = np.ones(P - 1)
    centre = -2 * np.ones(P)
    stencil = sp.diags(
        [neighbours, centre, neighbours],
        [-1, 0, 1],
        shape=(P, P),
        format='csr',
    )
    scale = w / 4.0
    return scale * stencil

def R_lotka_volterra(X, a=1.0, b=0.5, c=0.5, d=0.1, r_rest=0.3, K_rest=1.0):
    """Evaluate a column-wise reaction term on a 2-D state array.

    Column 0 is treated as prey and column 1 as predator (Lotka-Volterra
    coupling); every further column follows independent logistic growth.

    Args:
        X: 2-D array; entries are floored at ``1e-10`` before use.
        a: Prey growth coefficient.
        b: Coefficient of the prey*predator term in the prey equation.
        c: Predator decay coefficient.
        d: Coefficient of the prey*predator term in the predator equation.
        r_rest: Logistic growth rate for columns 2 and beyond.
        K_rest: Logistic carrying capacity for columns 2 and beyond.

    Returns:
        Array with the shape of ``X`` holding the reaction rates. With fewer
        than two columns the result is all zeros.
    """
    _, n_cols = X.shape
    state = np.maximum(X, 1e-10)
    rates = np.zeros_like(state)

    # The predator-prey pair needs both of the first two columns.
    if n_cols < 2:
        return rates

    prey, predator = state[:, 0:1], state[:, 1:2]
    rates[:, 0:1] = a * prey - b * prey * predator
    rates[:, 1:2] = -c * predator + d * prey * predator

    if n_cols > 2:
        others = state[:, 2:]
        rates[:, 2:] = r_rest * others * (1.0 - others / K_rest)
    return rates

def right_mul_dense_sparse(X, B):
    """Compute ``X @ B`` for a dense ``X`` and a dense or sparse ``B``.

    When ``B`` is a SciPy sparse matrix the product is formed as
    ``(B.T @ X.T).T`` so that the sparse operand is on the left of the
    multiplication; otherwise ``X @ B`` is evaluated directly.
    """
    if not sp.issparse(B):
        return X @ B
    transposed = B.T @ X.T
    return transposed.T

def select_evenly(n, k):
    """Pick ``k`` integer indices spread evenly over ``0 .. n-1``.

    Args:
        n: Size of the index range.
        k: Requested number of indices; converted with ``int`` and limited
            to at most ``n`` and at least 1.

    Returns:
        Integer array produced by ``np.linspace(0, n - 1, k, dtype=int)``.
    """
    count = min(int(k), n)
    if count < 1:
        count = 1
    return np.linspace(0, n - 1, count, dtype=int)

class SoftplusSkODE:
    """Low-rank ``X ~ C @ U @ R`` time stepper with softplus-parameterised factors.

    The factors are never stored directly. Instead the pre-activations
    ``W_C`` (P x k), ``W_U`` (k x k) and ``W_R`` (k x S) are kept and mapped
    through ``softplus_beta``, so the factors returned by ``factors()`` are
    always positive. ``I`` and ``J`` hold the ``k`` selected row and column
    indices on which the right-hand side is evaluated.
    """

    def __init__(self, L, B, R_func, X0, k, beta=2.0, use_rebalance=True, c2_eps=1e-12):
        """Set up the solver from an initial state.

        Args:
            L: Operator applied from the left (``L @ C``).
            B: Operator applied from the right (dense array or SciPy sparse).
            R_func: Callable mapping a 2-D state slice to reaction rates.
            X0: Initial 2-D state of shape (P, S).
            k: Requested rank; limited to ``[1, min(P, S)]``.
            beta: Softplus sharpness used for all three factors.
            use_rebalance: Whether ``step_sci`` renormalises factors after each step.
            c2_eps: Floor on the denominator of the relative-growth test.
        """
        self.L, self.B, self.R_func = L, B, R_func
        self.P, self.S = X0.shape
        self.k = int(min(max(1, k), self.P, self.S))
        self.beta = beta
        self.use_rebalance = bool(use_rebalance)
        self.c2_eps = float(c2_eps)
        self.I = select_evenly(self.P, self.k)
        self.J = select_evenly(self.S, self.k)
        self._init_weights(X0)

    def _init_weights(self, X0):
        """Initialise ``W_C``, ``W_R`` and ``W_U`` from slices of ``X0``.

        ``C`` starts as the columns ``J`` of ``X0``, ``R`` as the rows ``I``,
        both floored at ``1e-6``. ``U`` is obtained from two linear solves so
        that it approximates ``C_I^{-1} S0 R_J^{-1}`` (with a ``1e-8``
        diagonal shift), then floored at ``1e-6``; if a solve raises
        ``LinAlgError`` the identity (floored at ``1e-6``) is used instead.
        """
        eps = 1e-6
        C0 = np.maximum(X0[:, self.J], eps)
        R0 = np.maximum(X0[self.I, :], eps)
        S0 = np.maximum(X0[np.ix_(self.I, self.J)], eps)
        self.W_C = softplus_inv_beta(C0, self.beta)
        self.W_R = softplus_inv_beta(R0, self.beta)
        C_I, R_J = C0[self.I, :], R0[:, self.J]
        try:
            # Left solve removes C_I, the transposed solve removes R_J on the right.
            U0 = np.linalg.solve(C_I + 1e-8*np.eye(self.k), S0)
            U0 = np.linalg.solve((R_J + 1e-8*np.eye(self.k)).T, U0.T).T
            U0 = np.maximum(U0, eps)
        except np.linalg.LinAlgError:
            U0 = np.maximum(np.eye(self.k), eps)
        self.W_U = softplus_inv_beta(U0, self.beta)

    def factors(self):
        """Return the current ``(C, U, R)`` as softplus images of the stored weights."""
        C = softplus_beta(self.W_C, self.beta)
        U = softplus_beta(self.W_U, self.beta)
        R = softplus_beta(self.W_R, self.beta)
        return C, U, R

    def _anchors(self, C, R):
        """Return ``(L @ C, R @ B)``, the operator products reused by the linear terms."""
        LC = self.L @ C
        RB = right_mul_dense_sparse(R, self.B)
        return LC, RB

    def _rhs_slices(self, C, U, R, LC, RB):
        """Evaluate the right-hand side on columns ``J``, rows ``I`` and their intersection.

        The linear part is ``L X + X B`` for ``X = C U R``, restricted to each
        slice; ``R_func`` is then applied to the reconstructed slice and added.

        Returns:
            ``(F_CJ, F_RI, S, F_S)`` where ``F_CJ`` is (P x k), ``F_RI`` is
            (k x S), ``S`` is the current (k x k) intersection block and
            ``F_S`` its right-hand side.
        """
        I, J = self.I, self.J
        X_cJ = C @ (U @ R[:, J])
        X_rI = (C[I, :] @ U) @ R
        S    = (C[I, :] @ U) @ R[:, J]
        F_cJ_lin = (LC @ (U @ R[:, J])) + (C @ (U @ RB[:, J]))
        F_rI_lin = (LC[I, :] @ (U @ R)) + (C[I, :] @ (U @ RB))
        F_S_lin  = (LC[I, :] @ U @ R[:, J]) + (C[I, :] @ U @ RB[:, J])
        # NOTE(review): R_func receives the k selected columns re-indexed as
        # 0..k-1, so a reaction term keyed on column position acts on slice
        # positions rather than on original column indices - confirm intended.
        R_CJ = self.R_func(X_cJ)[:, :len(J)]
        R_RI = self.R_func(X_rI)
        R_S  = self.R_func(S)[:, :len(J)]
        F_CJ = F_cJ_lin + R_CJ
        F_RI = F_rI_lin + R_RI
        F_S  = F_S_lin  + R_S
        return F_CJ, F_RI, S, F_S

    def _update_WU_on_skeleton(self, C_I, R_J, S_tgt, iters=4, lr=0.2, l2=1e-6, clip=5.0):
        """Refit ``W_U`` so that ``C_I @ U @ R_J`` approaches ``S_tgt``.

        Runs ``iters`` gradient steps on the residual ``C_I U R_J - S_tgt``
        with respect to ``W_U``. Each update adds an ``l2 * W`` penalty term,
        is clipped element-wise to ``[-clip, clip]`` and scaled by ``lr``; the
        weights themselves are kept inside ``[-50, 50]``.
        """
        W = self.W_U
        for _ in range(iters):
            U = softplus_beta(W, self.beta)
            E = C_I @ U @ R_J - S_tgt
            G = C_I.T @ E @ R_J.T
            # Gradient w.r.t. U is mapped to W through sigmoid(beta*W)/beta.
            dW = G * (sigmoid_beta(W, self.beta)/self.beta) + l2*W
            dW = np.clip(dW, -clip, clip)
            W = np.clip(W - lr*dW, -50.0, 50.0)
        self.W_U = W

    def _rebalance(self):
        """Normalise columns of ``C`` and rows of ``R``, moving their norms into ``U``.

        The product ``C @ U @ R`` is unchanged by the rescaling itself; the
        subsequent floors (``U >= 1e-12``, ``C, R >= 0``) and the softplus
        inverse may perturb it slightly.
        """
        C, U, R = self.factors()
        eps = 1e-12
        cn = np.maximum(np.linalg.norm(C, axis=0), eps)
        rn = np.maximum(np.linalg.norm(R, axis=1), eps)
        C = C / cn
        R = R / rn[:, None]
        # Row/column scaling by broadcasting: same values as diag(cn) @ U @ diag(rn)
        # but without forming dense diagonal matrices, whose zero entries would
        # turn non-finite entries of U into NaN across whole rows/columns.
        U = (cn[:, None] * U) * rn[None, :]
        C = np.maximum(C, 0.0)
        R = np.maximum(R, 0.0)
        U = np.maximum(U, eps)
        self.W_C = softplus_inv_beta(C, self.beta)
        self.W_R = softplus_inv_beta(R, self.beta)
        self.W_U = softplus_inv_beta(U, self.beta)

    def step_sci(self, dt, rho=0.25, max_backtracks=8, wu_iters=4, wu_lr=0.2, wu_l2=1e-6, wu_clip=2.0):
        """Advance the factors by one explicit step with step-size backtracking.

        A proposal ``old + dt_try * F`` is formed for ``C``, ``R`` and the
        intersection block. It is accepted when all three proposals are
        non-negative and the largest relative change of each is at most
        ``rho``; otherwise ``dt_try`` is halved, at most ``max_backtracks``
        times. If no attempt is accepted, the last proposal is floored at
        ``1e-10`` and used anyway.

        Args:
            dt: Initial step size to try.
            rho: Bound on the element-wise relative change per step.
            max_backtracks: Maximum number of halvings (negative values act as 0).
            wu_iters, wu_lr, wu_l2, wu_clip: Forwarded to the ``W_U`` refit as
                ``iters``, ``lr``, ``l2`` and ``clip``.

        Returns:
            ``(dt_used, n_halvings, gmax)``: the step size of the proposal
            that was applied, how many times it was halved, and the largest
            relative change among the three proposals.
        """
        # A negative budget would skip the loop entirely and leave the
        # proposals undefined; treat it as "no backtracking".
        max_backtracks = max(0, int(max_backtracks))
        C0, U0, R0 = self.factors()
        LC0, RB0 = self._anchors(C0, R0)
        F_CJ, F_RI, S0, F_S = self._rhs_slices(C0, U0, R0, LC0, RB0)
        def growth(new, old, eps=self.c2_eps):
            # Largest element-wise relative change, with a floored denominator.
            den = np.maximum(np.abs(old), eps)
            return float(np.max(np.abs(new - old) / den))
        dt_try = dt
        bt = 0
        for q in range(max_backtracks + 1):
            Cn_prop = C0 + dt_try * F_CJ
            Rn_prop = R0 + dt_try * F_RI
            Sn_prop = S0 + dt_try * F_S
            ok1 = (np.min(Cn_prop) >= 0.0 and np.min(Rn_prop) >= 0.0 and np.min(Sn_prop) >= 0.0)
            gC = growth(Cn_prop, C0)
            gR = growth(Rn_prop, R0)
            gS = growth(Sn_prop, S0)
            ok2 = (gC <= rho and gR <= rho and gS <= rho)
            if ok1 and ok2:
                Cn, Rn, Sn = Cn_prop, Rn_prop, Sn_prop
                gmax = max(gC, gR, gS)
                break
            # Only shrink the step when another attempt will follow, so that
            # dt_try and bt always describe the proposal that is actually used.
            if q < max_backtracks:
                dt_try *= 0.5
                bt += 1
        else:
            # Backtracking exhausted: fall back to the last proposal, floored.
            Cn, Rn, Sn = np.maximum(Cn_prop, 1e-10), np.maximum(Rn_prop, 1e-10), np.maximum(Sn_prop, 1e-10)
            gmax = max(growth(Cn, C0), growth(Rn, R0), growth(Sn, S0))
        Cn = np.maximum(Cn, 1e-10)
        Rn = np.maximum(Rn, 1e-10)
        Sn = np.maximum(Sn, 1e-10)
        # Refit U first (it starts from the current W_U), then overwrite C and R.
        self._update_WU_on_skeleton(Cn[self.I, :], Rn[:, self.J], Sn, iters=wu_iters, lr=wu_lr, l2=wu_l2, clip=wu_clip)
        self.W_C = softplus_inv_beta(Cn, self.beta)
        self.W_R = softplus_inv_beta(Rn, self.beta)
        if self.use_rebalance:
            self._rebalance()
        return float(dt_try), int(bt), float(gmax)

    def skeleton(self):
        """Return the (k x k) block of ``C @ U @ R`` at rows ``I`` and columns ``J``."""
        C, U, R = self.factors()
        return (C[self.I, :] @ U) @ R[:, self.J]

    def reconstruct_full(self):
        """Return the full (P x S) product ``C @ U @ R`` as a dense array."""
        C, U, R = self.factors()
        return (C @ U) @ R

def _sampled_min(sol, max_elems=200_000, r_cap=1000, c_cap=1000, seed=7):
    """Return the minimum entry of the reconstructed state, sampling if it is large.

    Args:
        sol: Object exposing ``P``, ``S``, ``factors()`` and ``reconstruct_full()``.
        max_elems: Largest ``P * S`` for which the full matrix is formed.
        r_cap: Maximum number of rows sampled for larger problems.
        c_cap: Maximum number of columns sampled for larger problems.
        seed: Seed of the generator that draws the sampled rows and columns.

    Returns:
        The exact minimum when ``P * S <= max_elems``; otherwise the minimum
        over a randomly chosen rows-by-columns sub-block of ``C @ U @ R``.
    """
    n_rows, n_cols = sol.P, sol.S
    if n_rows * n_cols > max_elems:
        gen = np.random.default_rng(seed)
        # Rows are drawn before columns; both without replacement.
        row_idx = gen.choice(n_rows, size=min(n_rows, r_cap), replace=False)
        col_idx = gen.choice(n_cols, size=min(n_cols, c_cap), replace=False)
        C, U, R = sol.factors()
        block_vals = (C[row_idx, :] @ U) @ R[:, col_idx]
        return float(np.min(block_vals))
    return float(sol.reconstruct_full().min())

def _report_scales(sol):
    """Summarise the factor scales of a solver as a dict of floats.

    Args:
        sol: Object whose ``factors()`` returns ``(C, U, R)``.

    Returns:
        Dict with the min, max and max/min ratio of the column norms of ``C``
        and of the row norms of ``R`` (each norm offset by ``1e-15``), plus
        the condition number of ``U`` (``inf`` if its computation raises
        ``LinAlgError``).
    """
    C, U, R = sol.factors()
    tiny = 1e-15
    c_norms = np.linalg.norm(C, axis=0) + tiny
    r_norms = np.linalg.norm(R, axis=1) + tiny

    try:
        cond_u = np.linalg.cond(U)
    except np.linalg.LinAlgError:
        cond_u = np.inf

    c_lo, c_hi = np.min(c_norms), np.max(c_norms)
    r_lo, r_hi = np.min(r_norms), np.max(r_norms)
    return {
        "C_min_norm": float(c_lo),
        "C_max_norm": float(c_hi),
        "C_ratio_max/min": float(c_hi / c_lo),
        "R_min_norm": float(r_lo),
        "R_max_norm": float(r_hi),
        "R_ratio_max/min": float(r_hi / r_lo),
        "cond(U)": float(cond_u),
    }

def run_spcur_sci_experiment(P=100_000, S=500, k=20,
                             total_steps=100, w=0.1,
                             beta=2.0, rho=0.50,
                             seed=7, use_rebalance=True,
                             dt0=0.1, wu_iters=4, wu_lr=0.2, wu_l2=1e-6, wu_clip=2.0):
    """Run the low-rank solver on a random initial state and print a report.

    A (P x S) initial state with entries in ``[0.5, 0.6)`` is drawn from the
    seeded global NumPy generator, with the first ``P // 2`` rows of column 0
    doubled. The solver is stepped ``total_steps`` times from the same
    initial step guess ``dt0`` (up to 10 halvings per step), and timing,
    positivity, step-size and factor-scale statistics are printed.

    Args:
        P, S: Number of rows and columns of the state.
        k: Requested rank.
        total_steps: Number of solver steps.
        w: Weight passed to ``tri_laplacian``.
        beta: Softplus sharpness passed to the solver.
        rho: Relative-growth bound passed to each step.
        seed: Seed for ``np.random.seed``.
        use_rebalance: Whether the solver rebalances after each step.
        dt0: Initial step-size guess for every step.
        wu_iters, wu_lr, wu_l2, wu_clip: Forwarded to ``step_sci``.

    Returns:
        None; all results are written to standard output.
    """
    def _summarise(reducer, samples):
        # NaN stands in for statistics of an empty run (total_steps == 0).
        return float(reducer(samples)) if len(samples) else float('nan')

    np.random.seed(seed)
    X0 = 0.5 + 0.1 * np.random.rand(P, S)
    if S >= 1:
        X0[0:P//2, 0] *= 2.0

    sol = SoftplusSkODE(
        tri_laplacian(P, w=w),
        sp.csr_matrix((S, S)),
        R_lotka_volterra,
        X0,
        k=k,
        beta=beta,
        use_rebalance=use_rebalance,
        c2_eps=1e-12,
    )

    dt = float(dt0)
    step_options = {
        "rho": rho,
        "max_backtracks": 10,
        "wu_iters": wu_iters,
        "wu_lr": wu_lr,
        "wu_l2": wu_l2,
        "wu_clip": wu_clip,
    }
    mins, used_dts, bt_list, gmax_list = [], [], [], []

    t0 = perf_counter()
    for _ in range(total_steps):
        step_dt, step_bt, step_growth = sol.step_sci(dt, **step_options)
        used_dts.append(step_dt)
        bt_list.append(step_bt)
        gmax_list.append(step_growth)
        mins.append(_sampled_min(sol))
    t1 = perf_counter()

    num_params = sum(factor.size for factor in sol.factors())
    stats = _report_scales(sol)
    final_min = _summarise(np.min, mins)
    med_dt = _summarise(np.median, used_dts)
    avg_bt = _summarise(np.mean, bt_list)
    med_g = _summarise(np.median, gmax_list)
    tag = "ON" if use_rebalance else "OFF"

    print("=" * 60)
    print(f"SPCUR-SCI (Scale Rebalance: {tag})")
    print("=" * 60)
    print(f"Model Scale: P={P} (locations), S={S} (species) | Full size: {P*S:,}")
    print(f"Low Rank: k={sol.k} (requested={k}) | Parameters: {num_params:,}")
    print("-" * 60)

    print("1. Scalability")
    print(f"Total time for {total_steps} steps: {(t1 - t0):.4f}s")
    print(f"(Solved {P}x{S} ODEs via rank-{sol.k} approximation)")

    print("\n2. Positivity")
    print(f"Estimated min(population) over sim: {final_min:.3e}")
    if final_min >= 0:
        print("Kept ≥ 0")
    else:
        print("Negative sample detected")

    print("\n3. Step behavior under C2")
    print(f"Initial dt guess: {dt:.4e} | Median used dt: {med_dt:.4e}")
    print(f"Median max-relative-growth: {med_g:.4e} (≤ rho={rho})")
    print(f"Avg backtracks per step: {avg_bt:.2f}")

    print("\n4. Scale & Conditioning")
    for k_, v_ in stats.items():
        print(f"{k_:>18}: {v_:,.3e}" if isinstance(v_, float) else f"{k_:>18}: {v_}")

# Script entry point: run the default large-scale experiment.
if __name__ == "__main__":
    run_spcur_sci_experiment(
        P=100_000,
        S=500,
        k=20,
        total_steps=100,
        w=0.1,
        beta=2.0,
        rho=0.50,
        seed=7,
        use_rebalance=True,
        dt0=0.1,
    )


SPCUR-SCI (Scale Rebalance: ON)
Model Scale: P=100000 (locations), S=500 (species) | Full size: 50,000,000
Low Rank: k=20 (requested=20) | Parameters: 2,010,400
------------------------------------------------------------
1. Scalability
Total time for 100 steps: 24.7929s
(Solved 100000x500 ODEs via rank-20 approximation)

2. Positivity
Estimated min(population) over sim: 1.458e-10
Kept ≥ 0

3. Step behavior under C2
Initial dt guess: 1.0000e-01 | Median used dt: 1.0000e-01
Median max-relative-growth: 1.0040e-01 (≤ rho=0.5)
Avg backtracks per step: 0.22

4. Scale & Conditioning
        C_min_norm: 1.000e+00
        C_max_norm: 1.000e+00
   C_ratio_max/min: 1.000e+00
        R_min_norm: 1.000e+00
        R_max_norm: 1.000e+00
   R_ratio_max/min: 1.000e+00
           cond(U): 5.680e+295
