# Experiment 3b — Fixed inner weights (§5.3)

The dynamics are $\dot{x} = W_2(t)\,\sigma(W_1 x + b_1)$, where $W_1$ and $b_1$ are frozen and
only $W_2(t)$ is trained. The notebook also includes a scaling analysis (accuracy vs. $p$).

In [None]:
import os, sys, random

# Make the local ``rnode`` package importable: probe the working directory and
# then its parent, and put the first candidate that contains an ``rnode``
# directory (and is not already on sys.path) at the front of sys.path.
for _root in (os.getcwd(), os.path.abspath(os.path.join(os.getcwd(), os.pardir))):
    if not os.path.isdir(os.path.join(_root, "rnode")):
        continue
    if _root in sys.path:
        continue
    sys.path.insert(0, _root)
    break

import torch, numpy as np
from rnode.models import FixedInnerODE
from rnode.data import make_circles_data
from rnode.utils import compute_accuracy, count_trainable_params, count_total_params
from torchdiffeq import odeint

## Configuration

In [None]:
# Set QUICK = True for a reduced-cost run: fewer epochs, realizations and
# trials, and a shorter list of hidden dimensions in the scaling study.
QUICK = False

# Experiment settings shared by every cell below; the dict is handed as a whole
# to the helper functions imported from experiments.exp3b_plots.
CFG = dict(
    hidden_dim=24, T=2.0, seed=42,
    n_train_steps=20, n_eval_steps=500,
    alpha=0.01, beta=0.5,
    n_epochs=200 if QUICK else 1000,
    lr=1e-3,
    n_batches=3,
    n_realizations=3 if QUICK else 10,
    n_grid=150, h_reps=[1, 2, 3],
    n_real_boundary=5 if QUICK else 20,
    # Scaling
    hidden_dims_scaling=[24, 48, 96] if QUICK else [24, 48, 96, 192, 384, 512],
    n_trials_scaling=2 if QUICK else 5,
    # Benchmarks
    batch_configs=[2, 3, 4, 6],
    n_benchmark_trials=2 if QUICK else 5,
    n_realizations_bench=3 if QUICK else 10,
)

# Derive the notebook-level seed from CFG so the configured seed and the seed
# actually applied to the global RNGs cannot drift apart.
SEED = CFG["seed"]
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Artifacts are written to ./outputs relative to the current working directory.
OUT = os.path.join(os.getcwd(), "outputs")
os.makedirs(OUT, exist_ok=True)

# Uniform time grids on [0, T]: a coarse one for training, a fine one for evaluation.
t_train = torch.linspace(0, CFG["T"], CFG["n_train_steps"])
t_eval = torch.linspace(0, CFG["T"], CFG["n_eval_steps"])
# Spacing of t_eval: n_eval_steps points span n_eval_steps - 1 intervals.
step_eval = float(CFG["T"]) / (CFG["n_eval_steps"] - 1)

## Data

In [None]:
# Train and test sets come from separate make_circles_data calls with different
# seeds (42 and 123). The first argument (100 / 200) is presumably the number of
# samples -- confirm against rnode.data.
(X_train, y_train), (X_test, y_test) = (
    make_circles_data(100, seed=42),
    make_circles_data(200, seed=123),
)

## Reference NODE (fixed inner)

In [None]:
from experiments.exp3b_plots import train_model


def _rk4_trajectory(model, x0):
    """Integrate ``model`` from ``x0`` over ``t_eval`` with fixed-step RK4, gradients disabled."""
    with torch.no_grad():
        return odeint(model, x0, t_eval, method="rk4",
                      options={"step_size": step_eval})


# Re-seed torch immediately before constructing the reference model.
torch.manual_seed(SEED)
func_ref = FixedInnerODE(hidden_dim=CFG["hidden_dim"], seed=42)
n_total = count_total_params(func_ref)
n_trainable = count_trainable_params(func_ref)
print(f"Total params: {n_total}, Trainable: {n_trainable}")

# train_model returns the loss history and the elapsed training time.
losses_ref, time_ref = train_model(func_ref, X_train, y_train, t_train, CFG)
print(f"Time: {time_ref:.1f}s, final loss: {losses_ref[-1]:.4f}")

# Evaluate on the fine time grid; accuracy uses the state at the last time point.
y_ref = _rk4_trajectory(func_ref, X_train)
y_te = _rk4_trajectory(func_ref, X_test)
print(f"Train acc: {compute_accuracy(y_ref[-1], y_train):.2%}")
print(f"Test acc:  {compute_accuracy(y_te[-1], y_test):.2%}")

## Scaling analysis

In [None]:
from experiments.exp3b_plots import run_scaling_analysis

# Run the scaling study; CFG carries its settings (hidden_dims_scaling,
# n_trials_scaling) and OUT is the output directory passed to the helper.
scaling = run_scaling_analysis(
    X_train, y_train, X_test, y_test,
    t_train, t_eval, CFG, OUT,
)

## RBM ensemble

In [None]:
from experiments.exp3b_plots import train_rbm_ensemble

print(f"Training RBM ensemble (K={CFG['n_realizations']})...")
rbm = train_rbm_ensemble(
    X_train, y_train, t_train, t_eval, CFG, inner_seed=42,
)
print(f"Avg time/realisation: {rbm['avg_time']:.1f}s")

# Integrate every ensemble member on the test inputs (fixed-step RK4, no
# gradients) and keep only the state at the final time point.
te_preds = []
with torch.no_grad():
    for member in rbm["models"]:
        trajectory = odeint(member, X_test, t_eval, method="rk4",
                            options={"step_size": step_eval})
        te_preds.append(trajectory[-1])

# The ensemble test prediction is the mean of the members' final states.
y_te_avg = torch.stack(te_preds).mean(dim=0)
print(f"Ensemble train acc: {compute_accuracy(rbm['y_avg'][-1], y_train):.2%}")
print(f"Ensemble test acc:  {compute_accuracy(y_te_avg, y_test):.2%}")

## Decision boundaries

In [None]:
from experiments.exp3b_plots import compute_decision_boundaries

# The helper receives the trained reference model plus the data, time grids,
# CFG and the output directory, and returns a (boundaries, accuracies) pair.
boundaries, accuracies = compute_decision_boundaries(
    func_ref,
    X_train, y_train, X_test, y_test,
    t_train, t_eval,
    CFG, OUT,
)

## Benchmarks

In [None]:
from experiments.exp3b_plots import run_benchmarks

# Run the benchmark suite; CFG carries its settings (batch_configs,
# n_benchmark_trials, n_realizations_bench) and OUT is the output directory.
bench = run_benchmarks(
    X_train, y_train, X_test, y_test,
    t_train, t_eval, CFG, OUT,
)

## Save

In [None]:
# Persist the reference model's state dict inside the outputs directory.
ckpt_path = os.path.join(OUT, "ex3b_node.pth")
torch.save(func_ref.state_dict(), ckpt_path)
print("Done.")