<a href="https://colab.research.google.com/github/avbarbaros/Hybrid_Neural_ODE_Lotka-Volterra_Clean_Replication/blob/main/Predator_Prey_Hybrid_Neural_ODE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install torchdiffeq -q

In [7]:
"""
================================================================================
Hybrid Neural ODE: Lotka–Volterra Clean Replication
================================================================================

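Single experiment: matched-model, noise-free verification on the classical
predator–prey system. Tests whether the correction ratio
ρ = E[‖g(x)‖] / E[‖f(x)‖] (learned correction g relative to the mechanistic
field f) tends to 0 and whether the true parameters are recovered in a 2D
bilinear system from population dynamics.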

  ẋ = αx − βxy       (prey)
  ẏ = δxy − γy       (predator)

Ground truth: α=1.0, β=0.5, δ=0.5, γ=2.0

================================================================================
"""

# !pip install torchdiffeq -q

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
import time, os, warnings
warnings.filterwarnings('ignore')

# ── Configuration ─────────────────────────────────────────────────────────────
# All tensors created below default to double precision.
torch.set_default_dtype(torch.float64)

# Single source of settings read by the data generator, model, training loop
# and diagnostics in this file.
CONFIG = {
    # Ground-truth Lotka–Volterra rates used to generate the data.
    'alpha_true': 1.0,
    'beta_true':  0.5,
    'delta_true': 0.5,
    'gamma_true': 2.0,

    'T_horizon': 30.0,   # end time of each generated trajectory (starts at 0)
    'N_train': 2000,     # number of sample points in the training trajectory
    'N_val': 1000,       # number of sample points in the validation trajectory
    'seed': 42,          # base seed; individual RNGs add their own offsets

    'hidden_units': 32,  # default hidden width of CorrectionNetwork
    'batch_size': 8,     # training windows sampled (one optimizer step each) per epoch
    'seq_len': 80,       # window length, in samples, for training and validation
    'epochs': 60,
    'lr': 1e-3,          # Adam learning rate
    'lambda_l1': 1e-2,   # weight of the mean correction-norm penalty in the loss
    'grad_clip': 1.0,    # max gradient norm passed to clip_grad_norm_

    # Tolerances shared by solve_ivp (data generation) and odeint (model).
    'rtol': 1e-7,
    'atol': 1e-9,
    'solver': 'dopri5',  # `method` argument passed to odeint
}

# Directory that receives the saved figures; created on import if missing.
OUT = './lv_results'
os.makedirs(OUT, exist_ok=True)
# Seed the global torch and NumPy generators.
torch.manual_seed(CONFIG['seed'])
np.random.seed(CONFIG['seed'])

# ODE integrator used by the training loop and diagnostics.
from torchdiffeq import odeint
print(f"✓ torchdiffeq, dtype={torch.get_default_dtype()}")


# ── Equilibria ────────────────────────────────────────────────────────────────
def print_equilibria():
    """Print the fixed points of the ground-truth Lotka–Volterra system.

    Reads the true rates from ``CONFIG``, builds the Jacobian at the
    coexistence point E* = (γ/δ, α/β), and reports its eigenvalues together
    with the linearised oscillation period 2π/|Im λ|.
    """
    alpha = CONFIG['alpha_true']
    beta = CONFIG['beta_true']
    delta = CONFIG['delta_true']
    gamma = CONFIG['gamma_true']

    prey_eq = gamma / delta
    pred_eq = alpha / beta

    # Jacobian of (αx − βxy, δxy − γy) evaluated at E*.
    jacobian = np.array([[0, -beta * prey_eq], [delta * pred_eq, 0]])
    spectrum = np.linalg.eigvals(jacobian)
    period = 2 * np.pi / abs(spectrum[0].imag)

    print(
        f"\n  Equilibria (α={alpha}, β={beta}, δ={delta}, γ={gamma}):",
        "    E0=(0,0) → saddle",
        f"    E*=({prey_eq:.1f},{pred_eq:.1f}) → center, eigenvalues={spectrum.round(4)}",
        f"    Period ≈ {period:.2f} t.u.",
        sep="\n",
    )

print_equilibria()


# ── Data Generation ───────────────────────────────────────────────────────────
def generate_trajectory(n_points, seed_offset=0):
    """Simulate one noise-free trajectory of the ground-truth system.

    The initial condition is drawn uniformly from [2, 6] (prey) × [1, 3]
    (predator) using a private RNG seeded with ``CONFIG['seed'] + seed_offset``.
    The system is integrated with SciPy's RK45 on ``n_points`` equally spaced
    times in [0, ``CONFIG['T_horizon']``].

    Args:
        n_points: number of output time samples.
        seed_offset: added to the base seed so different calls get different
            initial conditions.

    Returns:
        Tuple ``(t, traj, x0)``: a float64 tensor of times, a float64 tensor
        of states with one row per time (columns: prey, predator), and the
        initial condition as a NumPy array.

    Raises:
        RuntimeError: if ``solve_ivp`` reports that integration failed, since
            the returned arrays would then be truncated.
    """
    rng = np.random.RandomState(CONFIG['seed'] + seed_offset)
    x0 = [rng.uniform(2.0, 6.0), rng.uniform(1.0, 3.0)]
    t_end = CONFIG['T_horizon']
    t_eval = np.linspace(0, t_end, n_points)

    def rhs(t, s):
        # Classical Lotka–Volterra right-hand side with the true rates.
        return [CONFIG['alpha_true']*s[0] - CONFIG['beta_true']*s[0]*s[1],
                CONFIG['delta_true']*s[0]*s[1] - CONFIG['gamma_true']*s[1]]

    sol = solve_ivp(
        rhs, (0, t_end), x0, method='RK45', t_eval=t_eval,
        rtol=CONFIG['rtol'], atol=CONFIG['atol'])
    if not sol.success:
        # Without this check a failed solve would silently yield fewer than
        # n_points samples.
        raise RuntimeError(f"solve_ivp failed: {sol.message}")

    return (torch.tensor(sol.t, dtype=torch.float64),
            torch.tensor(sol.y.T, dtype=torch.float64), np.array(x0))


# ── Model ─────────────────────────────────────────────────────────────────────
class CorrectionNetwork(nn.Module):
    """Small MLP g(x) supplying the learned residual vector field.

    A single hidden ReLU layer of width ``h`` maps a 2-component state to a
    2-component correction. The output layer's bias is initialised to zero.
    """

    def __init__(self, h=CONFIG['hidden_units']):
        super().__init__()
        # Built in input → output order so the default weight initialisation
        # consumes the torch RNG in the same sequence.
        hidden = nn.Linear(2, h)
        readout = nn.Linear(h, 2)
        nn.init.zeros_(readout.bias)
        self.net = nn.Sequential(hidden, nn.ReLU(), readout)

    def forward(self, x):
        """Return the correction for state(s) ``x`` (last dimension of size 2)."""
        return self.net(x)


class HybridLV(nn.Module):
    """Hybrid vector field: Lotka–Volterra terms plus a neural correction.

    The four rate parameters are trainable scalars initialised at the true
    values from ``CONFIG`` plus a small uniform perturbation drawn from a
    private RNG seeded with ``CONFIG['seed'] + 7``.
    """

    def __init__(self):
        super().__init__()
        self.learn_params = True
        rng = np.random.RandomState(CONFIG['seed'] + 7)
        # (parameter name, half-width of its uniform perturbation); the order
        # fixes both the RNG draws and the parameter registration order.
        perturbations = (('alpha', 0.1), ('beta', 0.05),
                         ('delta', 0.05), ('gamma', 0.2))
        for name, spread in perturbations:
            start = CONFIG[f'{name}_true'] + rng.uniform(-spread, spread)
            setattr(self, name, nn.Parameter(torch.tensor(start)))
        self.correction = CorrectionNetwork()

    def mechanistic(self, x):
        """Evaluate the Lotka–Volterra right-hand side with current parameters.

        ``x`` carries prey in index 0 and predator in index 1 of its last
        dimension; the result stacks (d prey, d predator) on a new last
        dimension.
        """
        prey = x[..., 0]
        pred = x[..., 1]
        prey_rate = self.alpha * prey - self.beta * prey * pred
        pred_rate = self.delta * prey * pred - self.gamma * pred
        return torch.stack([prey_rate, pred_rate], dim=-1)

    def forward(self, t, x):
        """Return the full vector field at state ``x``; ``t`` is not used."""
        physics = self.mechanistic(x)
        residual = self.correction(x)
        return physics + residual


# ── Training ──────────────────────────────────────────────────────────────────
def train(model, t_tr, traj_tr, t_va, traj_va):
    """Train the hybrid model on randomly sampled trajectory windows.

    Each epoch draws ``CONFIG['batch_size']`` window start indices from the
    training trajectory and takes one Adam step per window. The per-window
    loss is the MSE between the integrated and reference window plus
    ``CONFIG['lambda_l1']`` times the mean Euclidean norm of the correction
    network evaluated on the reference states. Afterwards four validation
    windows are integrated without gradients; their mean MSE drives a
    ``ReduceLROnPlateau`` scheduler.

    A window whose integration or update raises ``RuntimeError`` or
    ``AssertionError`` is skipped, in both the training and the validation
    phase. Any other exception (including ``KeyboardInterrupt``) propagates.

    Args:
        model: module callable as ``model(t, x)`` and exposing ``correction``,
            ``alpha``, ``beta``, ``delta`` and ``gamma``.
        t_tr, traj_tr: training time grid and matching states.
        t_va, traj_va: validation time grid and matching states.

    Returns:
        dict of per-epoch lists: 'train_loss', 'val_loss', 'alpha', 'beta',
        'delta', 'gamma'. A loss entry is NaN when every window of that
        phase was skipped.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['lr'])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=5)
    rng = np.random.RandomState(CONFIG['seed'] + 42)
    sl = CONFIG['seq_len']
    solver_kw = {'rtol': CONFIG['rtol'], 'atol': CONFIG['atol'],
                 'method': CONFIG['solver']}
    # Exceptions interpreted as "this window could not be processed".
    skippable = (RuntimeError, AssertionError)

    # Number of admissible window start indices (at least one).
    max_start = max(1, len(t_tr) - sl)
    max_val_start = max(1, len(t_va) - sl)

    history = {'train_loss': [], 'val_loss': [],
               'alpha': [], 'beta': [], 'delta': [], 'gamma': []}

    for epoch in range(CONFIG['epochs']):
        model.train()
        indices = rng.choice(max_start, size=CONFIG['batch_size'], replace=True)
        ep_loss = []

        for idx in indices:
            t_w = t_tr[idx:idx+sl]
            traj_w = traj_tr[idx:idx+sl]
            optimizer.zero_grad()
            try:
                # Integrate from the window's first state on a time axis
                # shifted to start at zero.
                pred = odeint(model, traj_w[0], t_w - t_w[0], **solver_kw)
                mse = torch.mean((pred - traj_w)**2)
                g_v = model.correction(traj_w)
                # Mean per-sample Euclidean norm of the correction.
                l1 = torch.mean(torch.norm(g_v, dim=-1))
                loss = mse + CONFIG['lambda_l1'] * l1
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['grad_clip'])
                optimizer.step()
                ep_loss.append(loss.item())
            except skippable:
                continue

        tr_l = np.mean(ep_loss) if ep_loss else float('nan')

        # Validate
        model.eval()
        with torch.no_grad():
            vi = rng.choice(max_val_start, size=4, replace=True)
            vl = []
            for v in vi:
                t_vw = t_va[v:v+sl]
                traj_vw = traj_va[v:v+sl]
                try:
                    vp = odeint(model, traj_vw[0], t_vw - t_vw[0], **solver_kw)
                    vl.append(torch.mean((vp - traj_vw)**2).item())
                except skippable:
                    # Same policy as the training loop; a bare ``except``
                    # here would also hide interrupts and programming errors.
                    continue
            va_l = np.mean(vl) if vl else float('nan')

        # The scheduler cannot rank NaN, so substitute a large sentinel.
        scheduler.step(va_l if not np.isnan(va_l) else 1e6)
        history['train_loss'].append(tr_l); history['val_loss'].append(va_l)
        history['alpha'].append(model.alpha.item())
        history['beta'].append(model.beta.item())
        history['delta'].append(model.delta.item())
        history['gamma'].append(model.gamma.item())

        if (epoch+1) % 5 == 0 or epoch == 0:
            print(f"  Ep {epoch+1:3d}/{CONFIG['epochs']}: "
                  f"L_tr={tr_l:.4e} L_val={va_l:.4e} | "
                  f"α={model.alpha.item():.5f} β={model.beta.item():.5f} "
                  f"δ={model.delta.item():.5f} γ={model.gamma.item():.5f}")

    return history


# ── Diagnostics ───────────────────────────────────────────────────────────────
@torch.no_grad()
def diagnostics(model, t, traj, label, chunk=200):
    """Evaluate the model on a trajectory, print a summary and return metrics.

    The prediction is stitched: the trajectory is cut into consecutive pieces
    of ``chunk`` samples and each piece is integrated from its own reference
    initial state (done to avoid long-integration blowup).

    Args:
        model: module callable as ``model(t, x)`` and exposing ``correction``
            and ``mechanistic``.
        t: time grid tensor.
        traj: reference states, one row per time.
        label: heading used in the printed summary.
        chunk: number of samples per stitched piece (default 200).

    Returns:
        dict with 'mse', 'rmse', 'var_exp' (100·(1 − mse/var(traj)), floored
        at 0, with the variance pooled over all entries of ``traj``), 'rho'
        (mean ‖g‖ divided by mean ‖f‖ on the reference states), 'pred'
        (stitched prediction, NumPy), 'g' (corrections, NumPy) and 'g_abs'
        (mean absolute correction per component).

    Raises:
        ValueError: if ``chunk`` is smaller than 1.
    """
    if chunk < 1:
        raise ValueError(f"chunk must be >= 1, got {chunk}")

    # Stitched prediction to avoid long-integration blowup
    preds = []
    for i in range(0, len(t), chunk):
        tc = t[i:i+chunk]
        trajc = traj[i:i+chunk]
        p = odeint(model, trajc[0], tc - tc[0],
                   rtol=CONFIG['rtol'], atol=CONFIG['atol'], method=CONFIG['solver'])
        preds.append(p)
    pred = torch.cat(preds, dim=0)

    mse = torch.mean((pred - traj)**2).item()
    rmse = np.sqrt(mse)
    var_exp = max(0, (1 - mse/torch.var(traj).item())*100)

    g = model.correction(traj)
    f = model.mechanistic(traj)
    # Mean correction norm is needed both for ρ and for the printed summary.
    g_norm_mean = torch.mean(torch.norm(g, dim=-1))
    rho = (g_norm_mean / torch.mean(torch.norm(f, dim=-1))).item()
    g_abs = torch.mean(torch.abs(g), dim=0).numpy()

    print(f"\n  ┌─ {label}")
    print(f"  │  MSE:          {mse:.6e}")
    print(f"  │  RMSE:         {rmse:.6e}")
    print(f"  │  Var Explained: {var_exp:.4f}%")
    print(f"  │  ρ:            {rho:.6f}")
    print(f"  │  E[||g||]:     {g_norm_mean.item():.6e}")
    print(f"  │  E[|Δprey|]:   {g_abs[0]:.6e}")
    print(f"  │  E[|Δpred|]:   {g_abs[1]:.6e}")
    print(f"  └{'─'*50}")

    return {'mse': mse, 'rmse': rmse, 'var_exp': var_exp, 'rho': rho,
            'pred': pred.numpy(), 'g': g.numpy(), 'g_abs': g_abs}


# ── Figures ───────────────────────────────────────────────────────────────────
def make_figures(history, model, t_tr, traj_tr, t_va, traj_va, d_tr, d_va):
    """Render the diagnostic figures and save them as PNG files under ``OUT``.

    Files written: convergence.png, params.png, traj_train.png, traj_val.png,
    phase.png and corrections.png.

    Args:
        history: dict returned by ``train`` ('train_loss', 'val_loss',
            'alpha', 'beta', 'delta', 'gamma' per-epoch lists).
        model: accepted for interface compatibility; not used here.
        t_tr, traj_tr: training time grid and states (tensors).
        t_va, traj_va: validation time grid and states (tensors).
        d_tr, d_va: dicts returned by ``diagnostics`` for the two sets; the
            keys 'pred', 'g' and 'rho' are read.
    """

    # 1. Convergence
    fig, ax = plt.subplots(figsize=(7, 4.5))
    ep = range(1, len(history['train_loss'])+1)
    ax.semilogy(ep, history['train_loss'], 'b-', lw=2, label='train')
    ax.semilogy(ep, history['val_loss'], 'orange', lw=2, label='val')
    ax.set_xlabel('Epoch'); ax.set_ylabel('Loss (log)')
    ax.set_title('Training Convergence'); ax.legend(); ax.grid(True, alpha=0.3)
    plt.tight_layout(); plt.savefig(f'{OUT}/convergence.png', dpi=150); plt.close()

    # 2. Parameter convergence
    fig, ax = plt.subplots(figsize=(8, 5))
    for key, lab in [('alpha', 'α'), ('beta', 'β'), ('delta', 'δ'), ('gamma', 'γ')]:
        # Reference value comes from CONFIG so the dashed line always matches
        # the ground truth actually used to generate the data.
        tv = CONFIG[f'{key}_true']
        line, = ax.plot(ep, history[key], lw=2, label=f'{lab} (true={tv})')
        ax.axhline(tv, color=line.get_color(), ls='--', alpha=0.4)
    ax.set_xlabel('Epoch'); ax.set_ylabel('Value')
    ax.set_title('Parameter Convergence'); ax.legend(); ax.grid(True, alpha=0.3)
    plt.tight_layout(); plt.savefig(f'{OUT}/params.png', dpi=150); plt.close()

    # 3. Trajectories (train + val)
    for t, traj, pred, name in [(t_tr, traj_tr, d_tr['pred'], 'train'),
                                 (t_va, traj_va, d_va['pred'], 'val')]:
        tn = t.numpy(); trajn = traj.numpy()
        fig, axes = plt.subplots(2, 1, figsize=(10, 5), sharex=True)
        for i, lab in enumerate(['Prey x(t)', 'Predator y(t)']):
            axes[i].plot(tn, trajn[:,i], 'b-', lw=1.5, label='true')
            axes[i].plot(tn, pred[:,i], 'r--', lw=1.5, alpha=0.8, label='pred')
            axes[i].set_ylabel(lab); axes[i].legend(); axes[i].grid(True, alpha=0.3)
        axes[-1].set_xlabel('t')
        fig.suptitle(f'Trajectories ({name})')
        plt.tight_layout(); plt.savefig(f'{OUT}/traj_{name}.png', dpi=150); plt.close()

    # 4. Phase portrait
    fig, ax = plt.subplots(figsize=(7, 6))
    tn = traj_tr.numpy()
    ax.plot(tn[:,0], tn[:,1], 'b-', lw=1.5, label='True', alpha=0.7)
    ax.plot(d_tr['pred'][:,0], d_tr['pred'][:,1], 'r--', lw=1.5, label='Predicted', alpha=0.7)
    # Coexistence equilibrium E* = (γ/δ, α/β) of the ground-truth system.
    xs, ys = CONFIG['gamma_true']/CONFIG['delta_true'], CONFIG['alpha_true']/CONFIG['beta_true']
    ax.plot(xs, ys, 'k*', ms=12, label=f'E*=({xs:.0f},{ys:.0f})')
    ax.set_xlabel('Prey'); ax.set_ylabel('Predator')
    ax.set_title('Phase Portrait'); ax.legend(); ax.grid(True, alpha=0.3)
    plt.tight_layout(); plt.savefig(f'{OUT}/phase.png', dpi=150); plt.close()

    # 5. Corrections
    tn = t_tr.numpy(); g = d_tr['g']
    gn = np.linalg.norm(g, axis=-1)
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    axes[0].plot(tn, gn, 'b-', lw=1); axes[0].set_xlabel('t'); axes[0].set_ylabel('||g||')
    axes[0].set_title(f'||g(x(t))||  (ρ={d_tr["rho"]:.4f})')
    axes[1].plot(tn, g[:,0], lw=1, label='Δprey')
    axes[1].plot(tn, g[:,1], lw=1, label='Δpred')
    axes[1].set_xlabel('t'); axes[1].set_title('Components'); axes[1].legend()
    axes[2].hist(gn, bins=50, edgecolor='k', alpha=0.7)
    axes[2].set_xlabel('||g||'); axes[2].set_title('Histogram')
    stats = f'Mean={np.mean(gn):.3e}\nMax={np.max(gn):.3e}'
    axes[2].text(0.95, 0.95, stats, transform=axes[2].transAxes, va='top', ha='right',
                 fontsize=9, bbox=dict(boxstyle='round', fc='wheat', alpha=0.5))
    for ax in axes: ax.grid(True, alpha=0.3)
    fig.suptitle('Learned Corrections (Train)')
    plt.tight_layout(); plt.savefig(f'{OUT}/corrections.png', dpi=150); plt.close()

    print(f"  Figures saved to {OUT}/")


# ── Main ──────────────────────────────────────────────────────────────────────
if __name__ == '__main__':
    # End-to-end run: generate data, train, evaluate, report, plot.
    print("\n" + "═"*60)
    print("  Hybrid Neural ODE — Lotka–Volterra Clean Replication")
    print("═"*60)

    # Different seed offsets give the two trajectories different initial
    # conditions.
    t_tr, traj_tr, ic_tr = generate_trajectory(CONFIG['N_train'], 0)
    t_va, traj_va, ic_va = generate_trajectory(CONFIG['N_val'], 100)
    print(f"  IC train: {ic_tr.round(3)},  IC val: {ic_va.round(3)}")

    model = HybridLV()
    print(f"  Init: α={model.alpha.item():.4f} β={model.beta.item():.4f} "
          f"δ={model.delta.item():.4f} γ={model.gamma.item():.4f}")
    print(f"  Trainable params: {sum(p.numel() for p in model.parameters())}")

    t0 = time.time()
    history = train(model, t_tr, traj_tr, t_va, traj_va)
    print(f"\n  Training time: {time.time()-t0:.1f}s")

    model.eval()
    d_tr = diagnostics(model, t_tr, traj_tr, "Train")
    d_va = diagnostics(model, t_va, traj_va, "Val")

    # Parameter recovery table
    print(f"\n  {'─'*55}")
    print(f"  {'Param':<6} {'True':>8} {'Learned':>12} {'Error%':>10}")
    print(f"  {'─'*55}")
    for nm, attr in [('α', 'alpha'), ('β', 'beta'), ('δ', 'delta'), ('γ', 'gamma')]:
        # True values are read from CONFIG so the table stays consistent with
        # the ground truth used for data generation.
        tv = CONFIG[f'{attr}_true']
        lv = getattr(model, attr).item()
        print(f"  {nm:<6} {tv:>8.4f} {lv:>12.6f} {abs(lv-tv)/tv*100:>10.4f}")

    # Summary
    print(f"\n  {'─'*55}")
    print(f"  {'Metric':<20} {'Train':>14} {'Val':>14}")
    print(f"  {'─'*55}")
    print(f"  {'MSE':<20} {d_tr['mse']:>14.4e} {d_va['mse']:>14.4e}")
    print(f"  {'RMSE':<20} {d_tr['rmse']:>14.4e} {d_va['rmse']:>14.4e}")
    print(f"  {'Var Explained':<20} {d_tr['var_exp']:>13.4f}% {d_va['var_exp']:>13.4f}%")
    print(f"  {'ρ':<20} {d_tr['rho']:>14.6f} {d_va['rho']:>14.6f}")

    make_figures(history, model, t_tr, traj_tr, t_va, traj_va, d_tr, d_va)
    print("\n  ✓ Done.")

✓ torchdiffeq, dtype=torch.float64

  Equilibria (α=1.0, β=0.5, δ=0.5, γ=2.0):
    E0=(0,0) → saddle
    E*=(4.0,2.0) → center, eigenvalues=[0.+1.4142j 0.-1.4142j]
    Period ≈ 4.44 t.u.

════════════════════════════════════════════════════════════
  Hybrid Neural ODE — Lotka–Volterra Clean Replication
════════════════════════════════════════════════════════════
  IC train: [3.498 2.901],  IC val: [5.608 2.116]
  Init: α=0.9602 β=0.4747 δ=0.5426 γ=2.1566
  Trainable params: 166
  Ep   1/60: L_tr=5.1959e-01 L_val=1.0052e+00 | α=0.95222 β=0.48266 δ=0.54298 γ=2.15625
  Ep   5/60: L_tr=1.3140e-02 L_val=5.6715e-04 | α=0.92071 β=0.51297 δ=0.53267 γ=2.16248
  Ep  10/60: L_tr=4.3485e-03 L_val=2.3788e-04 | α=0.92514 β=0.51013 δ=0.53522 γ=2.15756
  Ep  15/60: L_tr=3.3410e-03 L_val=2.6371e-04 | α=0.93035 β=0.50615 δ=0.53537 γ=2.15656
  Ep  20/60: L_tr=2.3592e-03 L_val=2.8095e-04 | α=0.93730 β=0.50116 δ=0.53563 γ=2.15581
  Ep  25/60: L_tr=2.0341e-03 L_val=5.4494e-04 | α=0.94484 β=0.49623 δ=0.53634