In [None]:
# phaseA_discover.py
import numpy as np
from dataclasses import dataclass
from scipy.signal import savgol_filter
from sklearn.linear_model import LassoLars, LinearRegression

# ---------------------- helpers: smoothing & derivative ----------------------
def estimate_derivative(t, x, window=11, poly=3):
    t = np.asarray(t)
    x = np.asarray(x)
    if x.ndim == 1:
        x = x[:, None]
    dt = np.median(np.diff(t))
    window = max(window, poly + 3)
    if window % 2 == 0:
        window += 1
    x_smooth = savgol_filter(x, window_length=window, polyorder=poly, axis=0, mode="mirror")
    dxdt = savgol_filter(x, window_length=window, polyorder=poly, deriv=1, delta=dt, axis=0, mode="mirror")
    return x_smooth, dxdt

# ---------------------- weak/integral regression matrix ----------------------
def build_weak_form_system(t, Phi, y, win=10, stride=5):
    """
    Build A, b from windowed integrals:
      x(t_{i+w}) - x(t_i) ≈ ∫_{t_i}^{t_{i+w}} Phi(t) dt · theta
    """
    t = np.asarray(t); Phi = np.asarray(Phi); y = np.asarray(y).ravel()
    dt = np.diff(t)
    rows = []
    rhs  = []
    for i in range(0, len(t) - win - 1, stride):
        j = i + win
        # trapezoidal integral for each feature
        integ = np.trapezoid(Phi[i:j+1, :], t[i:j+1], axis=0)  # shape (p,)
        rows.append(integ)
        rhs.append(y[j] - y[i])  # exact delta from raw y (or use smoothed)
    A = np.vstack(rows) if rows else np.zeros((0, Phi.shape[1]))
    b = np.asarray(rhs).ravel()
    return A, b

# ---------------------- sparse regression with sign constraints --------------
def fit_sparse_with_signs(Phi, y, signs, alpha=1e-3, thresh=1e-4, max_iter=8, ridge=1e-8):
    """
    Sequential thresholded least squares with per-column sign constraints.
    signs: list like [None, '>=0', '<=0', ...] aligned to Phi columns.
    """
    Phi = np.asarray(Phi); y = np.asarray(y).ravel()
    n, p = Phi.shape
    if p == 0:
        return np.zeros(0), np.array([], dtype=bool)

    # Flip columns so constrained coefficients are nonnegative; track flip.
    flip = np.ones(p)
    pos_mask = np.zeros(p, dtype=bool)
    for j, s in enumerate(signs):
        if s in (">=0", ">="):
            pos_mask[j] = True
        elif s in ("<=0", "<="):
            pos_mask[j] = True
            flip[j] = -1.0
            Phi[:, j] *= -1.0

    coef = np.zeros(p)
    active = np.ones(p, dtype=bool)

    for _ in range(max_iter):
        Aidx = np.where(active)[0]
        if len(Aidx) == 0:
            break
        PhiA = Phi[:, Aidx]
        Apos = np.isin(Aidx, np.where(pos_mask)[0])

        coefA = np.zeros(PhiA.shape[1])

        # Free subset
        if (~Apos).any():
            Phi_free = PhiA[:, ~Apos]
            if Phi_free.shape[1]:
                l_free = LassoLars(alpha=alpha, fit_intercept=False)
                l_free.fit(Phi_free, y)
                coefA[~Apos] = l_free.coef_

        # Constrained (>=0 after flip)
        if Apos.any():
            Phi_pos = PhiA[:, Apos]
            y_resid = y - (PhiA[:, ~Apos] @ coefA[~Apos] if (~Apos).any() else 0.0)
            if Phi_pos.shape[1]:
                l_pos = LassoLars(alpha=alpha, fit_intercept=False, positive=True)
                l_pos.fit(Phi_pos, y_resid)
                beta = np.clip(l_pos.coef_, 0.0, None)
                # tiny ridge refine
                G = Phi_pos.T @ Phi_pos + ridge * np.eye(Phi_pos.shape[1])
                b = Phi_pos.T @ y_resid
                try:
                    beta = np.maximum(np.linalg.solve(G, b), 0.0)
                except np.linalg.LinAlgError:
                    pass
                coefA[Apos] = beta

        coef = np.zeros_like(coef)
        coef[Aidx] = coefA

        small = np.abs(coef) < thresh
        active = ~small
        if not active.any():
            break

    # Unflip for '<=0'
    coef *= flip
    support = np.abs(coef) >= thresh

    # Final unbiased refit on support (respect sign by projection)
    if support.any():
        lr = LinearRegression(fit_intercept=False).fit(Phi[:, support], y)
        coef_final = np.zeros_like(coef)
        coef_final[support] = lr.coef_
        for j, s in enumerate(signs):
            if s in (">=0", ">="):
                coef_final[j] = max(coef_final[j], 0.0)
            elif s in ("<=0", "<="):
                coef_final[j] = min(coef_final[j], 0.0)
        return coef_final, support
    return coef, support

# ---------------------- candidate library for dT/dt --------------------------
@dataclass
class Term:
    name: str
    expr: callable   # lambda(states, inputs) -> (n,)
    sign: str | None # '>=0', '<=0', or None
    group: str | None = None

def sat(z, K=0.5):
    return z / (K + np.abs(z) + 1e-12)

def tumor_library_spec(KT=1.0, Ksat=0.5):
    return [
        Term('T(1-T/K)',    lambda S,I: S['T']*(1.0 - S['T']/KT),     '>=0', 'growth'),
        Term('ER*T',        lambda S,I: S['ER']*S['T'],               None,  'hormone'),
        Term('PI3K*T',      lambda S,I: S['PI3K']*S['T'],             None,  'pi3k'),
        Term('-E*T/(K+T)',  lambda S,I: - S['E']*(S['T']/(Ksat+S['T']+1e-12)), '<=0', 'immune'),
        Term('-Adr*Ki67*T', lambda S,I: - I['Adr']*S['Ki67']*S['T'],  '<=0', 'chemo'),
        Term('-Tax*Ki67*T', lambda S,I: - I['Tax']*S['Ki67']*S['T'],  '<=0', 'chemo'),
        Term('-Tam*ER*T',   lambda S,I: - I['Tam']*S['ER']*S['T'],    '<=0', 'endo'),
    ]

def build_library(terms, states, inputs):
    cols, names, signs, groups = [], [], [], []
    for tm in terms:
        v = tm.expr(states, inputs).reshape(-1, 1)
        v = np.nan_to_num(v, nan=0.0, posinf=0.0, neginf=0.0)
        cols.append(v); names.append(tm.name); signs.append(tm.sign); groups.append(tm.group)
    Phi = np.hstack(cols) if cols else np.zeros((len(next(iter(states.values()))), 0))
    return Phi, names, signs, groups

# ---------------------- synthetic data (LV + therapy pulses) -----------------
def simulate_synthetic(t):
    """
    Simple LV with logistic tumor + two drug pulses; ER, PI3K, Ki67 are smooth drivers.
    This is ONLY to sanity-check the pipeline.
    """
    dt = np.median(np.diff(t))
    T = np.zeros_like(t); E = np.zeros_like(t)
    ER = 0.6 + 0.1*np.sin(2*np.pi*t/60.0)        # mild oscillation
    PI3K = 0.4 + 0.1*np.cos(2*np.pi*t/80.0)
    Ki67 = 0.5 + 0.15*np.tanh(0.02*(t-40))       # rises mid-course

    Adr = np.exp(-0.25*np.maximum(t-20, 0)) * (t>=20)  # pulse ~ day 20
    Tax = np.exp(-0.25*np.maximum(t-70, 0)) * (t>=70)  # pulse ~ day 70
    Tam = 0.0 * t                                     # off in this toy

    # true params (ground truth we hope to recover sparsely)
    KT=1.0; Ksat=0.5
    a_ER=0.6; a_PI3K=0.2
    k_immune=0.8
    k_Adr=0.9; k_Tax=0.7

    T[0]=0.2; E[0]=0.5
    for i in range(len(t)-1):
        tt = t[i]
        growth = T[i]*(1 - T[i]/KT) + a_ER*ER[i]*T[i] + a_PI3K*PI3K[i]*T[i]
        kill_immune = - k_immune * E[i]*(T[i]/(Ksat+T[i]))
        kill_drug = - (k_Adr*Adr[i] + k_Tax*Tax[i]) * Ki67[i]*T[i]
        dT = growth + kill_immune + kill_drug

        dE = 0.15*E[i]*(T[i]/(0.6+T[i])) - 0.12*E[i]  # harmless E dyn for realism

        T[i+1] = np.clip(T[i] + dt*dT, 1e-6, 2.0)
        E[i+1] = np.clip(E[i] + dt*dE, 1e-6, 2.0)

    # add light noise
    rng = np.random.default_rng(0)
    Tn = np.clip(T + 0.01*rng.normal(size=T.shape), 1e-6, None)
    En = np.clip(E + 0.01*rng.normal(size=E.shape), 1e-6, None)

    states = {'T':Tn, 'E':En, 'ER':ER, 'PI3K':PI3K, 'Ki67':Ki67}
    inputs = {'Adr':Adr, 'Tax':Tax, 'Tam':Tam}
    return states, inputs

# ---------------------- main: run both discovery modes -----------------------
if __name__ == "__main__":
    # time grid
    t = np.linspace(0, 120, 241)  # 0..120 days, 0.5d step

    # synthetic "patient"
    states, inputs = simulate_synthetic(t)

    # Build candidate features for dT/dt
    terms = tumor_library_spec(KT=1.0, Ksat=0.5)
    Phi, names, signs, groups = build_library(terms, states, inputs)

    # ---------- Option A: SG + Sparse (classic) ----------
    Ts, dTdt = estimate_derivative(t, states['T'], window=21, poly=3)
    yA = dTdt[:, 0]
    coefA, suppA = fit_sparse_with_signs(Phi, yA, signs, alpha=2e-3, thresh=2e-4)
    print("\n=== Phase A (SG + sparse) discovered dT/dt ===")
    for nm, c, on in zip(names, coefA, suppA):
        if on:
            print(f"{nm:<18s} -> {c:+.4f}")

    # ---------- Option B: Weak/Integral form (no derivative) ----------
    A, b = build_weak_form_system(t, Phi, states['T'], win=12, stride=4)
    coefB, suppB = fit_sparse_with_signs(A, b, signs, alpha=2e-3, thresh=2e-4)
    print("\n=== Phase A (weak/integral) discovered dT/dt ===")
    for nm, c, on in zip(names, coefB, suppB):
        if on:
            print(f"{nm:<18s} -> {c:+.4f}")



=== Phase A (SG + sparse) discovered dT/dt ===
T(1-T/K)           -> +0.1318
ER*T               -> +0.0658
-E*T/(K+T)         -> +0.0000

=== Phase A (weak/integral) discovered dT/dt ===
T(1-T/K)           -> +0.3616
ER*T               -> +0.0714
PI3K*T             -> +0.2291
-Adr*Ki67*T        -> +0.0000
-Tax*Ki67*T        -> +0.0000
