# Experiment 1 — Forward pass (§5.1)

Validates Theorem 1 (trajectory convergence, rate $O(h)$), compares
batch sampling schemes, and benchmarks runtime.

In [None]:
import os, sys

def ensure_repo_on_path(start=None, package="rnode"):
    """Make the first directory that contains ``package`` importable.

    The candidates are ``start`` (default: the current working directory) and
    its parent, so the notebook works when launched either from the directory
    holding ``package/`` or from one level below it.

    The search stops at the first candidate that contains ``package``, even if
    that directory is already on ``sys.path``. Otherwise a parent that also
    happened to hold a ``package`` directory could be inserted in front of it
    and shadow the intended one.

    Args:
        start: directory to start from; ``None`` means ``os.getcwd()``.
        package: name of the sub-directory that marks the project root.

    Returns:
        The matching directory, or ``None`` when neither candidate contains
        ``package`` (``sys.path`` is then left untouched).
    """
    start = os.path.abspath(os.getcwd() if start is None else start)
    for root in (start, os.path.abspath(os.path.join(start, ".."))):
        if os.path.isdir(os.path.join(root, package)):
            if root not in sys.path:
                sys.path.insert(0, root)
            return root
    return None


ensure_repo_on_path()

import torch, numpy as np, torch.nn as nn
from torchdiffeq import odeint
from rnode.models import ConstantODE, TimeDepODE
from rnode.data import make_circles_data


## Configuration

In [None]:
import os

QUICK = False  # Set True for a fast test run (~3 min)

# Experiment configuration. Entries written as `quick if QUICK else full`
# shrink when QUICK is enabled; every other entry is shared by both modes.
# Key order is kept stable on purpose.
CFG = {
    # Problem setup and global RNG seed.
    "p": 24, "d": 2, "T": 1.0, "seed": 42,
    # Training (epochs per model, Adam learning rate, training-set size).
    "train_epochs_const": 400 if QUICK else 2000,
    "train_epochs_tdep": 400 if QUICK else 1700,
    "lr": 1e-3, "n_train": 100,
    # Convergence study.
    "N_conv": 2000, "n_real_conv": 10 if QUICK else 50,
    # Decision-boundary study.
    "N_boundary": 500, "n_grid": 100, "n_real_db": 5 if QUICK else 20,
    "h_reps_db": [1, 2, 3],
    # Runtime benchmarks.
    "data_sizes": [1000, 5000] if QUICK else [1000, 5000, 10000, 15000, 20000, 25000, 30000],
    "n_trials_bench": 3 if QUICK else 10, "N_bench": 500,
    # Cost-vs-error scatter.
    "h_scatter": [0.001, 0.01, 0.1],
    "n_real_scatter": 8 if QUICK else 30, "n_time_repeats": 2 if QUICK else 5,
    # Remaining studies (realisation counts and step sizes).
    "n_real_scheme": 8 if QUICK else 30,
    "n_real_pareto": 8 if QUICK else 30, "h_pareto": 0.005,
    "n_real_opth": 8 if QUICK else 40,
    "h_S_estimate": 0.01,
}

# Output directory: passed to every figure/table helper and used for the
# saved model weights. Created up front so the helpers can write into it.
OUT = os.path.join(os.getcwd(), "outputs")
os.makedirs(OUT, exist_ok=True)

# Seed both global RNGs from the config.
torch.manual_seed(CFG["seed"])
np.random.seed(CFG["seed"])

## Train models

In [None]:
# Two independently seeded circles datasets: (Xc, yc) trains `mc`
# (ConstantODE) and (Xt, yt) trains `mt` (TimeDepODE).
(Xc, yc), (Xt, yt) = (make_circles_data(CFG["n_train"], seed=s) for s in (42, 2))

# Time grid used by `train` for every forward integration: 20 evenly
# spaced points on [0, T].
t_tr = torch.linspace(start=0, end=CFG["T"], steps=20)

def train(model, X, y, epochs, label, t=None):
    """Fit ``model`` so the first state coordinate at the final time matches ``y``.

    Each epoch integrates ``model`` from the initial states ``X`` over the time
    grid ``t`` with ``odeint``. It then takes one Adam step (learning rate
    ``CFG["lr"]``) on the MSE between the first coordinate of the final state
    and ``y``. Progress is printed about four times over the run.

    Args:
        model: right-hand side passed to ``odeint``; its parameters are
            updated in place.
        X: batch of initial states (2-D, since the trajectory is indexed as
            ``traj[time, sample, coord]``).
        y: regression targets. A 1-D ``y`` is treated as a column so that it
            lines up with the ``(n, 1)`` prediction instead of broadcasting.
        epochs: number of full-batch optimisation steps.
        label: tag shown in the progress messages.
        t: integration time grid; ``None`` uses the module-level ``t_tr``.

    Returns:
        The same ``model`` object, after training.
    """
    if t is None:
        t = t_tr
    # With a 1-D target, MSELoss would broadcast (n, 1) against (n,) into an
    # (n, n) comparison and silently optimise the wrong objective.
    target = y.unsqueeze(-1) if y.dim() == 1 else y
    loss_fn = nn.MSELoss()  # stateless: build once, not every epoch
    opt = torch.optim.Adam(model.parameters(), lr=CFG["lr"])
    report_every = max(1, epochs // 4)
    for ep in range(epochs):
        opt.zero_grad()
        traj = odeint(model, X, t)
        # Final time step, all samples, first coordinate kept as a column.
        loss = loss_fn(traj[-1, :, :1], target)
        loss.backward()
        opt.step()
        if (ep + 1) % report_every == 0:
            print(f"  [{label}] {ep+1}/{epochs}  loss={loss.item():.6f}")
    return model

# Reseed torch from the config before building each model, so both models
# presumably start from the same reproducible initialisation. Using
# CFG["seed"] instead of a hard-coded 42 keeps this in step with the global
# seeding in the configuration cell.
torch.manual_seed(CFG["seed"])
mc = train(ConstantODE(CFG["p"], CFG["d"]), Xc, yc, CFG["train_epochs_const"], "const")
torch.manual_seed(CFG["seed"])
mt = train(TimeDepODE(CFG["p"], CFG["d"], net_hidden=20), Xt, yt, CFG["train_epochs_tdep"], "tdep")

## Generate figures

In [None]:
# Import the plotting helpers by name rather than with `import *`, so the
# plotting module cannot silently rebind notebook names such as `os`,
# `torch`, `np`, `CFG` or `OUT`.
from experiments.exp1_plots import (
    fig1_trajectories,
    fig2_convergence,
    fig3_decision,
    fig4_benchmarks,
    fig5_cost_vs_error,
    fig6_scheme_convergence,
    fig7_variance,
    fig8_pareto,
    fig9_optimal_h,
    fig10_error_constant_vs_pimin,
    fig11_speedup,
    table1_batch_counts,
)

# Fig. 1: trajectories of both trained models on their own training data
# (output presumably written under OUT).
fig1_trajectories(mc, mt, Xc, yc, Xt, yt, CFG, OUT)

In [None]:
# Fig. 2: trajectory convergence for both models (the O(h) rate of Theorem 1
# named in the header); presumably sized by CFG["N_conv"] / CFG["n_real_conv"].
fig2_convergence(mc, mt, Xc, Xt, CFG, OUT)

In [None]:
# Fig. 3: decision plot for both models. No data is passed, so it presumably
# evaluates on its own grid (CFG["n_grid"], CFG["N_boundary"]).
fig3_decision(mc, mt, CFG, OUT)

In [None]:
# Fig. 4: runtime benchmarks for both models; presumably sweeps
# CFG["data_sizes"] with CFG["n_trials_bench"] trials each.
fig4_benchmarks(mc, mt, CFG, OUT)

In [None]:
# Fig. 5: cost vs. error (constant model only); presumably uses
# CFG["h_scatter"], CFG["n_real_scatter"] and CFG["n_time_repeats"].
fig5_cost_vs_error(mc, Xc, CFG, OUT)

In [None]:
# Fig. 6: convergence per sampling scheme (constant model only); presumably
# the batch-sampling-scheme comparison from the header, using
# CFG["n_real_scheme"].
fig6_scheme_convergence(mc, Xc, CFG, OUT)

In [None]:
# Fig. 7: variance study (constant model only).
fig7_variance(mc, Xc, CFG, OUT)

In [None]:
# Fig. 8: Pareto plot (constant model only); presumably uses
# CFG["n_real_pareto"] and CFG["h_pareto"].
fig8_pareto(mc, Xc, CFG, OUT)

In [None]:
# Fig. 9: optimal h (constant model only); presumably uses CFG["n_real_opth"].
fig9_optimal_h(mc, Xc, CFG, OUT)

In [None]:
# Table 1: batch counts (constant model only). Unlike the figure calls
# around it, it also receives the labels `yc`.
table1_batch_counts(mc, Xc, yc, CFG, OUT)

In [None]:
# Fig. 10: error constant vs. pi_min (constant model only).
fig10_error_constant_vs_pimin(mc, Xc, CFG, OUT)

In [None]:
# Fig. 11: speed-up (constant model only).
fig11_speedup(mc, Xc, CFG, OUT)

## Save

In [None]:
# Persist only the learned parameters (state dicts) of both models into OUT,
# constant model first.
for _fname, _model in (("model_const.pth", mc), ("model_tdep.pth", mt)):
    torch.save(_model.state_dict(), os.path.join(OUT, _fname))
print("Done.")