# Experiment 3a — Training with RBM (§5.3)

Both layers of the model are trainable. Each rNODE realisation is trained with
a **fixed** batch schedule (structured pruning), and predictions are
ensemble-averaged over K realisations (`CFG["n_realizations"]`).

**Contents:** Reference NODE, RBM ensemble, decision boundaries, benchmarks.

In [None]:
import os, sys, random

# Make the local `rnode` package importable when the notebook is started from
# the project root or from one of its direct sub-directories.
for _root in (os.getcwd(), os.path.abspath(os.path.join(os.getcwd(), ".."))):
    if os.path.isdir(os.path.join(_root, "rnode")):
        # Stop at the FIRST directory that contains the package, even when it
        # is already on sys.path; otherwise the parent directory could be
        # prepended and shadow the intended copy of `rnode`.
        if _root not in sys.path:
            sys.path.insert(0, _root)
        break

import torch, numpy as np
from rnode.models import TimeDepODE_ELU
from rnode.data import make_circles_data
from rnode.utils import compute_accuracy
from torchdiffeq import odeint

## Configuration

In [None]:
QUICK = False  # True: fewer epochs / realisations / trials for a fast smoke run

CFG = dict(
    hidden_dim=24, T=2.0, seed=42,
    n_train_steps=20, n_eval_steps=500,
    alpha=0.01, beta=0.5,
    n_epochs=1000 if not QUICK else 200,
    lr=1e-3,
    # RBM ensemble
    n_batches=3,
    n_realizations=10 if not QUICK else 3,
    # Decision boundaries
    n_grid=150, h_reps=[1, 2, 3],
    n_real_boundary=20 if not QUICK else 5,
    # Benchmarks
    batch_configs=[2, 3, 4, 6],
    n_benchmark_trials=5 if not QUICK else 2,
    n_realizations_bench=10 if not QUICK else 3,
)

# Single source of truth for the seed. CFG (including "seed") is handed to the
# training helpers, so a separate literal here could silently diverge from it.
SEED = CFG["seed"]
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

OUT = os.path.join(os.getcwd(), "outputs")
os.makedirs(OUT, exist_ok=True)

# Coarse time grid for training, fine grid for evaluation / plotting.
t_train = torch.linspace(0, CFG["T"], CFG["n_train_steps"])
t_eval = torch.linspace(0, CFG["T"], CFG["n_eval_steps"])
# Fixed RK4 step equal to the spacing of t_eval
# (n_eval_steps points span n_eval_steps - 1 intervals).
step_eval = float(CFG["T"]) / (CFG["n_eval_steps"] - 1)

## Data

In [None]:
# Circles data set: train and test samples are drawn with different seeds.
n_train_pts, n_test_pts = 100, 200
X_train, y_train = make_circles_data(n_train_pts, seed=42)
X_test, y_test = make_circles_data(n_test_pts, seed=123)
print(f"Train: {len(X_train)}, Test: {len(X_test)}")

## Reference NODE

In [None]:
from experiments.exp3a_plots import train_model, plot_dynamics

# Reference model: seed right before construction so its initial weights
# are reproducible from SEED.
torch.manual_seed(SEED)
func_ref = TimeDepODE_ELU(hidden_dim=CFG["hidden_dim"])
losses_ref, time_ref = train_model(func_ref, X_train, y_train, t_train, CFG)
print(f"Training time: {time_ref:.1f}s, final loss: {losses_ref[-1]:.4f}")

# Roll the trained model out on the fine evaluation grid with fixed-step RK4;
# no gradients are needed for inference.
rk4_options = {"step_size": step_eval}
with torch.no_grad():
    y_ref = odeint(func_ref, X_train, t_eval, method="rk4", options=rk4_options)
print(f"Train acc: {compute_accuracy(y_ref[-1], y_train):.2%}")

plot_dynamics(y_ref, y_train, "NODE", OUT, "ex3a_node_dynamics")

## RBM ensemble

In [None]:
from experiments.exp3a_plots import train_rbm_ensemble

print(f"Training RBM ensemble ({CFG['n_batches']} batches, K={CFG['n_realizations']})...")
rbm = train_rbm_ensemble(X_train, y_train, t_train, t_eval, CFG)
print(f"Avg time/realisation: {rbm['avg_time']:.1f}s")

# Score the last time slice of rbm["y_avg"] (presumably the ensemble-averaged
# trajectory) against the training labels.
y_avg = rbm["y_avg"]
acc_rbm = compute_accuracy(y_avg[-1], y_train)
print(f"Ensemble train acc: {acc_rbm:.2%}")

plot_dynamics(y_avg, y_train, "rNODE", OUT, "ex3a_rnode_dynamics")

## Decision boundaries

In [None]:
from experiments.exp3a_plots import plot_decision_boundaries
from rnode.data import make_grid


def _final_state(model, x):
    """Integrate ``model`` from initial states ``x`` over ``t_eval`` with
    fixed-step RK4 and return only the state at the final time."""
    trajectory = odeint(model, x, t_eval, method="rk4",
                        options={"step_size": step_eval})
    return trajectory[-1]


n_grid = CFG["n_grid"]
xx, yy, grid = make_grid(n_points=n_grid)

boundaries, accuracies = {}, {}

# Reference NODE: the boundary map is the first coordinate of the final state
# on the grid; accuracy is measured on the test set.
print("Computing NODE boundary...")
with torch.no_grad():
    node_grid_T = _final_state(func_ref, grid)
    node_test_T = _final_state(func_ref, X_test)
boundaries["NODE"] = node_grid_T[:, 0].numpy()
accuracies["NODE"] = compute_accuracy(node_test_T, y_test)

# rNODE ensembles for each h = rep * dt; member predictions are averaged.
for rep in CFG["h_reps"]:
    print(f"Computing rNODE boundary (h={rep}*dt)...")
    # A fresh config copy per rep, so each call gets its own dict.
    cfg_bnd = {**CFG, "n_realizations": CFG["n_real_boundary"]}
    rbm_h = train_rbm_ensemble(X_train, y_train, t_train, t_eval,
                               cfg_bnd, rep=rep, verbose=False)
    per_member = []
    with torch.no_grad():
        for member in rbm_h["models"]:
            per_member.append((_final_state(member, grid)[:, 0],
                               _final_state(member, X_test)))
    grid_mean = torch.stack([g for g, _ in per_member]).mean(dim=0)
    test_mean = torch.stack([p for _, p in per_member]).mean(dim=0)
    label = f"h={rep}"
    boundaries[label] = grid_mean.numpy()
    accuracies[label] = compute_accuracy(test_mean, y_test)
    print(f"  acc = {accuracies[f'h={rep}']:.2%}")

plot_decision_boundaries(boundaries, accuracies, xx, yy,
                         X_test, y_test, OUT, "ex3a_decision_boundaries")

## Benchmarks

In [None]:
from experiments.exp3a_plots import plot_benchmark_table
from experiments.exp3a_plots import train_model as tm3a
from experiments.exp3a_plots import train_rbm_ensemble

# Long-running cell. QUICK=True does NOT skip it: it only reduces the number
# of epochs, trials and realisations (see CFG).
benchmark_results = {}


def _final_state(model, x):
    """Integrate ``model`` from ``x`` over ``t_eval`` (fixed-step RK4) and
    return the state at the final time."""
    return odeint(model, x, t_eval, method="rk4",
                  options={"step_size": step_eval})[-1]


def _summarise(trials):
    """Mean train accuracy, test accuracy and time over a list of trial dicts."""
    return {
        "train_acc_mean": np.mean([tr["train_acc"] for tr in trials]),
        "test_acc_mean": np.mean([tr["test_acc"] for tr in trials]),
        "time_mean": np.mean([tr["time"] for tr in trials]),
    }


# --- Reference NODE: one freshly initialised model per trial ---------------
print("Benchmarking NODE...")
node_trials = []
for trial in range(CFG["n_benchmark_trials"]):
    torch.manual_seed(SEED + trial)
    model = TimeDepODE_ELU(CFG["hidden_dim"])
    losses, elapsed = tm3a(model, X_train, y_train, t_train, CFG, verbose=False)
    with torch.no_grad():
        final_tr = _final_state(model, X_train)
        final_te = _final_state(model, X_test)
    node_trials.append({"time": elapsed, "loss": losses[-1],
                        "train_acc": compute_accuracy(final_tr, y_train),
                        "test_acc": compute_accuracy(final_te, y_test)})
benchmark_results["NODE"] = _summarise(node_trials)
print(f"  test={benchmark_results['NODE']['test_acc_mean']:.2%}")

# --- RBM ensembles for each number of batches ------------------------------
for nb in CFG["batch_configs"]:
    print(f"Benchmarking RBM ({nb} batches)...")
    trials = []
    for trial in range(CFG["n_benchmark_trials"]):
        # Seed each trial explicitly, mirroring the NODE trials above, so the
        # RBM benchmark is reproducible and every trial gets its own seed.
        # NOTE(review): whether train_rbm_ensemble reads cfg["seed"] or the
        # global torch RNG is not visible here -- both are set; confirm.
        torch.manual_seed(SEED + trial)
        cfg_b = {**CFG, "seed": SEED + trial, "n_batches": nb,
                 "n_realizations": CFG["n_realizations_bench"]}
        rbm_b = train_rbm_ensemble(X_train, y_train, t_train, t_eval,
                                   cfg_b, verbose=False)
        with torch.no_grad():
            te_preds = [_final_state(m, X_test) for m in rbm_b["models"]]
            tr_preds = [_final_state(m, X_train) for m in rbm_b["models"]]
        # Ensemble prediction = mean of the members' final states.
        trials.append({
            "train_acc": compute_accuracy(torch.mean(torch.stack(tr_preds), 0), y_train),
            "test_acc": compute_accuracy(torch.mean(torch.stack(te_preds), 0), y_test),
            # avg_time is the helper's average training time per realisation.
            "time": rbm_b["avg_time"],
        })
    benchmark_results[f"{nb}_batches"] = _summarise(trials)
    print(f"  test={benchmark_results[f'{nb}_batches']['test_acc_mean']:.2%}")

plot_benchmark_table(benchmark_results, CFG["batch_configs"], OUT, "ex3a_benchmark")

## Save

In [None]:
# Persist the reference NODE's weights (state_dict only) into OUT.
node_ckpt_path = os.path.join(OUT, "ex3a_node.pth")
torch.save(func_ref.state_dict(), node_ckpt_path)
print("Done.")