# Best Practices for Using Runners

## Prepare Models and Datasets as Usual

In [1]:
import todd
import torch
import torch.utils.data

[2m[2023-01-10 19:18:16,554 36763:4308452736][loggers.py:110 todd.base.patches.get_logger] DEBUG: logger initialized by lutingwang@wangluting.local[m
[2023-01-10 19:18:16,560 36763:4308452736][patches.py:36 todd.base.patches.<module>] INFO: `ipdb` is installed. Using it for debugging.
[2m[2023-01-10 19:18:18,122 36763:4308452736][loggers.py:110 todd.base.registries.get_logger] DEBUG: logger initialized by lutingwang@wangluting.local[m


Models should be built by users.
The same model can be used by multiple runners, such as a trainer and a validator, simultaneously.

In [2]:
import sys
class Model(todd.Module):

    def __init__(self) -> None:
        super().__init__()
        self._weight = torch.nn.Parameter(torch.tensor(2.0))

    @property
    def weight(self) -> float:
        return self._weight.item()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self._weight


model = Model()

In contrast to models, datasets are built inside runners.

In [3]:
class Dataset(torch.utils.data.Dataset[int]):

    def __init__(self, n: int) -> None:
        self._data = list(range(1, n + 1))

    def __len__(self) -> int:
        return len(self._data)

    def __getitem__(self, index: int) -> int:
        return self._data[index]

## Define a Mixin for All Runners

In [4]:
class RunnerMixin(todd.utils.BaseRunner):

    def _build_dataloader(
        self,
        config: todd.Config,
    ) -> torch.utils.data.DataLoader:
        dataset = Dataset(**config.pop('dataset'))
        return torch.utils.data.DataLoader(dataset, **config)

    def _run_iter(self, i: int, batch, memo: todd.utils.Memo) -> torch.Tensor:
        y: torch.Tensor = self._model(batch)
        loss = y.sum().abs()
        return loss

`DRY_RUN` is turned on by default when CUDA devices are not available.
To override this setting, manually set `DRY_RUN` to `False`.

In [5]:
todd.utils.BaseRunner.Store.DRY_RUN = False

## Validate

In [6]:
import os
import tempfile
from typing import cast

Define and register the validator.

In [7]:
@todd.utils.RunnerRegistry.register()
class CustomValidator(RunnerMixin, todd.utils.Validator):
    pass

Define the validator config. 
`config` will be reused by trainers.

In [8]:
config = todd.Config(
    model=model,
    log=dict(interval=5),
    load_state_dict=dict(model=dict(strict=False)),
    state_dict=dict(model=dict()),
)
validator = todd.Config(
    type='CustomValidator',
    dataloader=dict(batch_size=1, dataset=dict(n=20)),
)

Build and run the validator.
Logs will be saved to the working directory.

In [9]:
with tempfile.TemporaryDirectory() as work_dir:
    validator.name = work_dir
    runner = todd.utils.RunnerRegistry.build(validator, config)
    cast(CustomValidator, runner).run()
    print(os.listdir(work_dir))

[2m[2023-01-10 19:18:22,090 36763:4308452736][loggers.py:110 todd.utils.runners.4407437312.get_logger] DEBUG: logger initialized by lutingwang@wangluting.local[m
[2023-01-10 19:18:22,092 36763:4308452736][runners.py:166 todd.utils.runners.4407437312._run] INFO: Iter [5/20] Loss 10.000
[2023-01-10 19:18:22,093 36763:4308452736][runners.py:166 todd.utils.runners.4407437312._run] INFO: Iter [10/20] Loss 20.000
[2023-01-10 19:18:22,095 36763:4308452736][runners.py:166 todd.utils.runners.4407437312._run] INFO: Iter [15/20] Loss 30.000
[2023-01-10 19:18:22,097 36763:4308452736][runners.py:166 todd.utils.runners.4407437312._run] INFO: Iter [20/20] Loss 40.000


['20230110T191822f.log']


## Train

In [10]:
trainer = dict(
    dataloader=dict(batch_size=2, dataset=dict(n=67)),
    optimizer=dict(type='SGD', lr=0.01),
    load_state_dict=dict(optimizer=dict()),
    state_dict=dict(optimizer=dict(), interval=20),
)

### By Iter

In [11]:
@todd.utils.RunnerRegistry.register()
class CustomIterBasedTrainer(RunnerMixin, todd.utils.IterBasedTrainer):

    def _before_run_iter_log(
        self,
        i: int,
        batch,
        memo: todd.utils.Memo,
    ) -> str | None:
        info = super()._before_run_iter_log(i, batch, memo)
        if info is None:
            info = ''
        model: Model = self.model
        info += f" Weight {model.weight:.3f}"
        info += f" Batch {batch}"
        return info

By default, `_before_run_iter_log` returns `None`, meaning that no message will be printed.

In [12]:
with tempfile.TemporaryDirectory() as work_dir:
    iter_based_trainer = trainer.copy()
    iter_based_trainer.update(
        type='CustomIterBasedTrainer',
        name=work_dir,
        iters=53,
    )
    runner = todd.utils.RunnerRegistry.build(iter_based_trainer, config)
    cast(CustomIterBasedTrainer, runner).run()
    print(os.listdir(work_dir))

[2m[2023-01-10 19:18:22,189 36763:4308452736][loggers.py:110 todd.utils.runners.4407343136.get_logger] DEBUG: logger initialized by lutingwang@wangluting.local[m
[2023-01-10 19:18:22,191 36763:4308452736][runners.py:154 todd.utils.runners.4407343136._run] INFO:  Weight 1.640 Batch tensor([ 9, 10])
[2023-01-10 19:18:22,192 36763:4308452736][runners.py:166 todd.utils.runners.4407343136._run] INFO: Iter [6/53] Loss 31.160
[2023-01-10 19:18:22,194 36763:4308452736][runners.py:154 todd.utils.runners.4407343136._run] INFO:  Weight 0.290 Batch tensor([19, 20])
[2023-01-10 19:18:22,195 36763:4308452736][runners.py:166 todd.utils.runners.4407343136._run] INFO: Iter [11/53] Loss 11.310
[2023-01-10 19:18:22,198 36763:4308452736][runners.py:154 todd.utils.runners.4407343136._run] INFO:  Weight -0.180 Batch tensor([29, 30])
[2023-01-10 19:18:22,199 36763:4308452736][runners.py:166 todd.utils.runners.4407343136._run] INFO: Iter [16/53] Loss 10.620
[2023-01-10 19:18:22,201 36763:4308452736][runners

['iter_20.pth', 'iter_40.pth', 'latest.pth', '20230110T191822f.log']


Trainers increment `todd.Store.ITER` to keep track of the training progress.
If multiple trainers are to be run, `todd.Store.ITER` must be manually reset to zero.

### By Epoch

In [13]:
@todd.utils.RunnerRegistry.register()
class CustomEpochBasedTrainer(RunnerMixin, todd.utils.EpochBasedTrainer):
    pass

In [14]:
with tempfile.TemporaryDirectory() as work_dir:
    epoch_based_trainer = trainer.copy()
    epoch_based_trainer.update(
        type='CustomEpochBasedTrainer',
        name=work_dir,
        epochs=3,
    )
    runner = todd.utils.RunnerRegistry.build(epoch_based_trainer, config)
    cast(CustomEpochBasedTrainer, runner).run()
    print(os.listdir(work_dir))

[2m[2023-01-10 19:18:22,284 36763:4308452736][loggers.py:110 todd.utils.runners.4407439232.get_logger] DEBUG: logger initialized by lutingwang@wangluting.local[m
[2023-01-10 19:18:22,285 36763:4308452736][runners.py:423 todd.utils.runners.4407439232._run] INFO: Epoch [1/3] beginning
[2023-01-10 19:18:22,287 36763:4308452736][runners.py:166 todd.utils.runners.4407439232._run] INFO: Iter [5/34] Loss 1.330
[2023-01-10 19:18:22,289 36763:4308452736][runners.py:166 todd.utils.runners.4407439232._run] INFO: Iter [10/34] Loss 7.800
[2023-01-10 19:18:22,291 36763:4308452736][runners.py:166 todd.utils.runners.4407439232._run] INFO: Iter [15/34] Loss 15.930
[2023-01-10 19:18:22,292 36763:4308452736][runners.py:166 todd.utils.runners.4407439232._run] INFO: Iter [20/34] Loss 31.600
[2023-01-10 19:18:22,294 36763:4308452736][runners.py:166 todd.utils.runners.4407439232._run] INFO: Iter [25/34] Loss 46.530
[2023-01-10 19:18:22,296 36763:4308452736][runners.py:166 todd.utils.runners.4407439232._run

['epoch_1.pth', 'epoch_2.pth', 'epoch_3.pth', '20230110T191822f.log']


## Dry Run

In [15]:
todd.utils.BaseRunner.Store.DRY_RUN = True

If `DRY_RUN` is enabled, the runner will stop upon the first log message.

In [16]:
with tempfile.TemporaryDirectory() as work_dir:
    validator.name = work_dir
    runner = todd.utils.RunnerRegistry.build(validator, config)
    cast(CustomValidator, runner).run()

[2m[2023-01-10 19:18:22,395 36763:4308452736][loggers.py:110 todd.utils.runners.6438374992.get_logger] DEBUG: logger initialized by lutingwang@wangluting.local[m
[2023-01-10 19:18:22,397 36763:4308452736][runners.py:166 todd.utils.runners.6438374992._run] INFO: Iter [5/20] Loss 0.050


## State Dicts