# Quickstart: Training disRNN with ModelTrainee

This notebook demonstrates how to:
1. Define a disRNN model and training session using `ModelTrainee`
2. Persist the training configuration to a SQLite database
3. Train the model with automatic checkpointing
4. Resume training after interruption (via `GracefulKiller`)
5. Inspect training results

In [1]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

from pathlib import Path
from sqlalchemy import create_engine, select
from sqlalchemy.orm import Session

from disRNN_MP.rnn.train_db import Base, ModelTrainee, trainingSession

## Create SQLite database

`ModelTrainee` uses SQLAlchemy to persist model definitions, training sessions, and checkpoints.

In [2]:
engine = create_engine('sqlite:///quickstart_training.db', echo=False)
Base.metadata.create_all(engine)
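
To check what `create_all` wrote, the database file can be opened with Python's built-in `sqlite3` module. The snippet below is a minimal sketch: `list_tables` is a hypothetical helper, not part of `disRNN_MP`, and the demo uses a throwaway database with made-up table names so that it runs anywhere.

```python
import sqlite3
import tempfile
from contextlib import closing
from pathlib import Path


def list_tables(db_path):
    """Return the names of all tables in a SQLite file, sorted alphabetically."""
    with closing(sqlite3.connect(db_path)) as conn:
        rows = conn.execute(
            "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name"
        ).fetchall()
    return [name for (name,) in rows]


# Demo on a temporary database; `demo_a` and `demo_b` are placeholder names,
# not the tables that disRNN_MP creates.
with tempfile.TemporaryDirectory() as tmp:
    demo_db = Path(tmp) / "demo.db"
    with closing(sqlite3.connect(demo_db)) as conn:
        conn.execute("CREATE TABLE demo_b (id INTEGER PRIMARY KEY)")
        conn.execute("CREATE TABLE demo_a (id INTEGER PRIMARY KEY)")
        conn.commit()
    demo_tables = list_tables(demo_db)

print(demo_tables)
```

In this notebook, `list_tables("quickstart_training.db")` should list whatever tables the `Base` metadata defines; the actual names depend on the `disRNN_MP` models.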

## Define model and training session

We use Hydra-style `dry_*` configs (dictionaries with a `_target_` key) to specify the model, optimizer, loss function, and dataset. Storing the configuration this way makes it easy to reload a trained model in a new Python session, either to continue training or for analysis.
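
As a rough illustration of the `_target_` convention, the sketch below resolves a dotted path to a callable and applies the remaining keys as keyword arguments. It is not how `disRNN_MP` or Hydra implements this (Hydra's `instantiate` also handles nested configs and more); `resolve_target` and `build` are hypothetical names, and the targets are standard-library callables so the example runs without `disRNN_MP` installed.

```python
import importlib
from functools import partial


def resolve_target(path):
    """Import 'package.module.attr' and return the attribute."""
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


def build(cfg):
    """Call the `_target_` callable with the remaining keys as keyword arguments.

    With `_partial_: True`, return a functools.partial instead of calling it,
    so the caller can supply the missing arguments later.
    """
    kwargs = {k: v for k, v in cfg.items() if k not in ("_target_", "_partial_")}
    fn = resolve_target(cfg["_target_"])
    if cfg.get("_partial_", False):
        return partial(fn, **kwargs)
    return fn(**kwargs)


# A fully specified config is called immediately.
frac = build({"_target_": "fractions.Fraction", "numerator": 3, "denominator": 4})

# A partial config yields a callable that still expects its first argument.
to_binary = build({"_target_": "builtins.int", "_partial_": True, "base": 2})

print(frac, to_binary("101"))
```

The `_partial_` form matches how `dry_make_train_step` and `dry_make_param_metric` are declared below: the config fixes some arguments now, and the remaining ones are presumably supplied at training time.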

In [3]:
# Resolve the data path to an absolute path (the relative path is interpreted
# from the current working directory)
data_path = Path('../data/mp_beh_m18_500completed.npy').resolve().as_posix()

# ModelTrainee is the entry point to a model: it points to an architecture
# definition that is initialized with the given hyperparameters, and it holds
# references to all planned training sessions and trained snapshots.
mt = ModelTrainee(
    dry_model={
        '_target_': 'disRNN_MP.rnn.disrnn_def.make_transformed_disrnn',
        'latent_size': 5,
        'update_mlp_shape': [3, 3],
        'choice_mlp_shape': [2],
        'target_size': 2,
        'eval_mode': False,
    },
)

# A trainingSession represents one stage of training, which can have its own
# optimizer, training step function, loss function, and dataset.
se = trainingSession(
    dry_optimizer={
        '_target_': 'optax.adam',
        'learning_rate': 1e-3,
    },
    dry_make_train_step={
        '_target_': 'disRNN_MP.rnn.disrnn_def.make_train_step',
        '_partial_': True,
        'penalty_scale': 0,
    },
    dry_make_param_metric={
        '_target_': 'disRNN_MP.rnn.disrnn_def.make_param_metric_expLL',
        '_partial_': True,
    },
    dry_datasets={
        '_target_': 'disRNN_MP.dataset.train_test_datasets',
        'dat_or_path': data_path,
        'n_sess_sample': 42,
        'seed': 5,
    },
    n_step=200,
    steps_per_block=10,
)

# The `sessions` attribute of ModelTrainee is an ordered list: training runs
# through the sessions in the order given here.
mt.sessions = [se]

## Commit to database

In [4]:
with Session(engine) as sess:
    sess.add(mt)
    sess.commit()

## Train the model

Load the `ModelTrainee` from the database and call `.train()`. Training progress is
automatically checkpointed to the database every `steps_per_block` steps.

In [5]:
sess = Session(engine, expire_on_commit=False)
mt = sess.execute(select(ModelTrainee).where(ModelTrainee.id == 1)).scalar_one()
ret = mt.train(sess, 'worker1')

step 1 is done with loss: 1.5024e+04 (Time: 9.1s)
step 11 is done with loss: 1.4651e+04 (Time: 10.4s)
step 21 is done with loss: 1.4582e+04 (Time: 11.7s)
step 31 is done with loss: 1.4569e+04 (Time: 13.0s)
step 41 is done with loss: 1.4539e+04 (Time: 14.2s)
step 51 is done with loss: 1.4522e+04 (Time: 15.5s)
step 61 is done with loss: 1.4503e+04 (Time: 16.8s)
step 71 is done with loss: 1.4490e+04 (Time: 18.1s)
step 81 is done with loss: 1.4473e+04 (Time: 19.4s)
step 91 is done with loss: 1.4461e+04 (Time: 20.7s)
step 101 is done with loss: 1.4449e+04 (Time: 22.0s)
step 111 is done with loss: 1.4435e+04 (Time: 23.3s)
[2026-02-19T18:57:10.158]: received signal 2
[2026-02-19T18:57:10.161]: terminated at step 112
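
The log above reports a loss at steps 1, 11, 21, and so on, which is consistent with one record per block of `steps_per_block` steps. Assuming that pattern holds for the whole session (an inference from the log, not a documented guarantee), the steps at which records are expected can be listed with a hypothetical helper:

```python
def record_steps(n_step, steps_per_block):
    """Steps at which a record is expected, assuming one record per block
    starting at step 1 (inferred from the training log, not from the library)."""
    return list(range(1, n_step + 1, steps_per_block))


steps = record_steps(n_step=200, steps_per_block=10)
print(len(steps), steps[:3], steps[-1])
```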


## Resumable training

If training is interrupted (e.g., via Ctrl-C), `GracefulKiller` catches the signal and
saves the current state to the database. You can then reload and call `.train()` again
to resume from the last checkpoint.

Try interrupting the cell above with Ctrl-C during training, then run the cell below
to resume.
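
The general pattern behind this kind of graceful shutdown can be sketched with the standard `signal` module: a handler only sets a flag, and the training loop checks the flag at a step boundary so it can stop cleanly. This is an illustration of the technique, not the `GracefulKiller` implementation; `StopFlag` and `run_steps` are hypothetical names.

```python
import signal


class StopFlag:
    """Hypothetical stand-in for a graceful-shutdown helper."""

    def __init__(self, signals=(signal.SIGINT, signal.SIGTERM)):
        self.stop_requested = False
        # Install the handler and remember the previous ones for restore().
        self._previous = {s: signal.signal(s, self._handle) for s in signals}

    def _handle(self, signum, frame):
        # Only record the request; the loop decides when it is safe to stop.
        self.stop_requested = True

    def restore(self):
        """Reinstall the handlers that were active before this object."""
        for s, handler in self._previous.items():
            signal.signal(s, handler)


def run_steps(n_step, flag, on_step=None):
    """Run up to n_step steps; return the last step that completed."""
    done = 0
    for step in range(1, n_step + 1):
        if flag.stop_requested:
            break  # a real trainer would save its state here
        if on_step is not None:
            on_step(step)
        done = step
    return done


def interrupt_at_three(step):
    # Simulate Ctrl-C arriving during step 3.
    if step == 3:
        signal.raise_signal(signal.SIGINT)


flag = StopFlag()
try:
    completed = run_steps(10, flag, on_step=interrupt_at_three)
finally:
    flag.restore()

print(completed, flag.stop_requested)
```

Because the handler never raises, the step in progress always finishes before the loop exits, which is what makes it safe to persist state at that point.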

In [6]:
sess = Session(engine, expire_on_commit=False)
mt = sess.execute(select(ModelTrainee).where(ModelTrainee.id == 1)).scalar_one()
ret = mt.train(sess, 'worker1')  # resumes from last saved step

step 121 is done with loss: 1.4423e+04 (Time: 8.8s)
step 131 is done with loss: 1.4406e+04 (Time: 10.1s)
step 141 is done with loss: 1.4388e+04 (Time: 11.4s)
step 151 is done with loss: 1.4367e+04 (Time: 12.7s)
step 161 is done with loss: 1.4341e+04 (Time: 14.0s)
[2026-02-19T18:58:07.312]: received signal 2
[2026-02-19T18:58:07.315]: terminated at step 168


## Inspect training results

In [7]:
for rec in mt.records:
    print(f"step {rec.step}: train={rec.train_metric:.4f}, test={rec.test_metric:.4f}")

step 1: train=0.4902, test=0.4897
step 11: train=0.4981, test=0.4976
step 21: train=0.4993, test=0.4988
step 31: train=0.4998, test=0.4993
step 41: train=0.5004, test=0.4999
step 51: train=0.5008, test=0.5003
step 61: train=0.5013, test=0.5008
step 71: train=0.5016, test=0.5011
step 81: train=0.5020, test=0.5015
step 91: train=0.5023, test=0.5018
step 101: train=0.5026, test=0.5021
step 111: train=0.5029, test=0.5024
step 121: train=0.5032, test=0.5027
step 131: train=0.5036, test=0.5031
step 141: train=0.5040, test=0.5035
step 151: train=0.5046, test=0.5041
step 161: train=0.5052, test=0.5047
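
For a quick summary of the records, the sketch below picks the record with the best test metric and computes the train/test gap per step. It assumes that a higher metric is better (the name `make_param_metric_expLL` suggests a likelihood-based metric, but that is an assumption), and it uses a `namedtuple` stand-in with three rows copied from the output above, so it runs without the database. `best_by_test` and `train_test_gap` are hypothetical helpers.

```python
from collections import namedtuple

# Stand-in for the ORM record objects; only the three attributes used in the
# cell above are modeled.
Record = namedtuple("Record", ["step", "train_metric", "test_metric"])


def best_by_test(records):
    """Return the record with the highest test metric (assumes higher is better)."""
    return max(records, key=lambda r: r.test_metric)


def train_test_gap(records):
    """Return (step, train_metric - test_metric) pairs."""
    return [(r.step, r.train_metric - r.test_metric) for r in records]


sample = [
    Record(1, 0.4902, 0.4897),
    Record(81, 0.5020, 0.5015),
    Record(161, 0.5052, 0.5047),
]
best = best_by_test(sample)
gaps = train_test_gap(sample)
print(f"best test metric at step {best.step}: {best.test_metric:.4f}")
```

With the real data, pass `mt.records` instead of `sample`.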
