# Training

> Functions for training models

In [None]:
#| default_exp trainer

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
from relax.utils import show_doc
# Alias `show_doc` under the name used by the documentation-rendering cell below.
show_doc_parser = show_doc

In [None]:
#| export
from relax.import_essentials import *
from relax.data import TabularDataModule
from relax.module import BaseTrainingModule
from relax.logger import TensorboardLogger
from relax.utils import validate_configs
from relax._ckpt_manager import CheckpointManager

In [None]:
#| hide
#| export
class TrainingConfigs(BaseParser):
    """Configurator of 'train_model'."""

    # Number of passes over the training set; `train_model_with_states`
    # iterates `range(n_epochs)`, so this must be set before training.
    n_epochs: int = Field(default=None, description="Number of epochs.")
    # Forwarded to both the train and validation dataloaders.
    batch_size: int = Field(default=None, description="Batch size.")
    # Base metric name; the trainer monitors its "<name>_epoch" variant.
    monitor_metrics: Optional[str] = Field(
        default=None,
        description="Monitor metrics used to evaluate the training result after each epoch.",
    )
    # Seed for the `hk.PRNGSequence` exposed by `PRNGSequence`.
    seed: int = Field(default=42, description="Seed for generating random number.")
    # Root directory for logger output.
    log_dir: str = Field(
        default="log",
        description="The name for the directory that holds logged data during training.",
    )
    # Sub-directory (under `log_dir`) for this run's logs.
    logger_name: str = Field(
        default="debug",
        description="The name for the directory that holds logged data during training under log directory.",
    )
    # Passed to the logger as `on_step`.
    log_on_step: bool = Field(
        default=False,
        description="Log the evaluate metrics at the current step.",
    )
    # Upper bound on checkpoints kept by the checkpoint manager.
    max_n_checkpoints: int = Field(
        default=3,
        description="Maximum number of checkpoints stored.",
    )

    @property
    def PRNGSequence(self):
        """Return a new `hk.PRNGSequence` seeded with `self.seed`.

        A fresh sequence object is built on every access, so two accesses
        yield independent iterators that produce the same keys.
        """
        seed_value = self.seed
        return hk.PRNGSequence(seed_value)


In [None]:
# Render the API documentation (signature and field table) for `TrainingConfigs`.
show_doc_parser(TrainingConfigs)

---

[source](https://github.com/birkhoffg/cfnet/blob/master/relax/trainer.py#L15){target="_blank" style="float:right; font-size:smaller"}

### TrainingConfigs

>      TrainingConfigs (n_epochs:int=None, batch_size:int=None,
>                       monitor_metrics:Union[str,NoneType]=None, seed:int=42,
>                       log_dir:str='log', logger_name:str='debug',
>                       log_on_step:bool=False, max_n_checkpoints:int=3)

Configurator of 'train_model'.

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| n_epochs | int |  | Number of epochs. |
| batch_size | int |  | Batch size. |
| monitor_metrics | typing.Union[str, NoneType] |  | Monitor metrics used to evaluate the training result after each epoch. |
| seed | int | 42 | Seed for generating random number. |
| log_dir | str | log | The name for the directory that holds logged data during training. |
| logger_name | str | debug | The name for the directory that holds logged data during training under log directory. |
| log_on_step | bool | False | Log the evaluate metrics at the current step. |
| max_n_checkpoints | int | 3 | Maximum number of checkpoints stored. |

In [None]:
#| export
def train_model_with_states(
    training_module: BaseTrainingModule,
    params: hk.Params,
    opt_state: optax.OptState,
    data_module: TabularDataModule,
    t_configs: Union[Dict[str, Any], TrainingConfigs],
) -> Tuple[hk.Params, optax.OptState]:
    """Train `training_module` from existing parameters and optimizer state.

    Runs `t_configs.n_epochs` epochs. Each epoch has a training pass followed
    by a validation pass. Progress goes to a `TensorboardLogger`, and the
    checkpoints are updated after every epoch.

    Args:
        training_module: Module providing `init_logger`, `training_step` and
            `validation_step`. Its `logger` attribute is used after
            `init_logger` has been called.
        params: Network parameters to start from.
        opt_state: Optimizer state to start from.
        data_module: Source of the train and validation dataloaders.
        t_configs: A `TrainingConfigs` instance or a dict that
            `validate_configs` converts into one.

    Returns:
        The `(params, opt_state)` pair after the last epoch.
    """
    t_configs = validate_configs(t_configs, TrainingConfigs)
    keys = t_configs.PRNGSequence

    # Logger: record the training configs and, if present, the module hparams.
    logger = TensorboardLogger(
        log_dir=t_configs.log_dir,
        name=t_configs.logger_name,
        on_step=t_configs.log_on_step,
    )
    logger.save_hyperparams(t_configs.dict())
    if training_module.hparams:
        logger.save_hyperparams(training_module.hparams)
    training_module.init_logger(logger)

    # Close the module's logger even if training raises. The original code
    # only closed it when every epoch completed successfully.
    try:
        # The checkpoint manager tracks the "<metric>_epoch" variant of the
        # configured monitor metric, or nothing when none is configured.
        monitor_metrics = (
            None
            if t_configs.monitor_metrics is None
            else f"{t_configs.monitor_metrics}_epoch"
        )
        ckpt_manager = CheckpointManager(
            log_dir=Path(training_module.logger.log_dir) / "checkpoints",
            monitor_metrics=monitor_metrics,
            max_n_checkpoints=t_configs.max_n_checkpoints,
        )

        train_loader = data_module.train_dataloader(t_configs.batch_size)
        val_loader = data_module.val_dataloader(t_configs.batch_size)

        for epoch in range(t_configs.n_epochs):
            training_module.logger.on_epoch_started()

            # Training pass; only the final epoch's progress bar stays on screen.
            with tqdm(
                train_loader, unit="batch", leave=epoch == t_configs.n_epochs - 1
            ) as t_loader:
                t_loader.set_description(f"Epoch {epoch}")
                for batch in t_loader:
                    x, y = map(device_put, tuple(batch))
                    params, opt_state = training_module.training_step(
                        params, opt_state, next(keys), (x, y)
                    )
                    # Show the most recently logged values in the progress bar.
                    t_loader.set_postfix(**training_module.logger.get_last_logs())

            # Validation pass. The return value of `validation_step` is not
            # used; the epoch summary is read back from the logger instead.
            for batch in val_loader:
                x, y = map(device_put, tuple(batch))
                training_module.validation_step(params, next(keys), (x, y))

            epoch_logs = training_module.logger.on_epoch_finished()
            ckpt_manager.update_checkpoints(params, opt_state, epoch_logs, epoch)
    finally:
        training_module.logger.close()

    return params, opt_state


In [None]:
#| export
def train_model(
    training_module: BaseTrainingModule, # Training module
    data_module: TabularDataModule, # Data module
    t_configs: Union[Dict[str, Any], TrainingConfigs], # Training configurator
) -> Tuple[hk.Params, optax.OptState]:
    """Initialize the network and optimizer, then train the classifier.

    Returns the `(params, opt_state)` pair produced by
    `train_model_with_states`.
    """
    configs = validate_configs(t_configs, TrainingConfigs)
    # NOTE(review): `train_model_with_states` builds its own PRNGSequence from
    # the same seed, so its first key appears to coincide with `init_key`.
    # Consider deriving independent key streams (this would change results).
    init_key = next(configs.PRNGSequence)
    params, opt_state = training_module.init_net_opt(data_module, init_key)
    return train_model_with_states(
        training_module=training_module,
        params=params,
        opt_state=opt_state,
        data_module=data_module,
        t_configs=configs,
    )

## Test

In [None]:
from relax.data import TabularDataModule
from relax.module import PredictiveTrainingModule

# Key names (including `continous_cols` / `discret_cols`) are spelled as the
# original config spelled them; presumably TabularDataModule expects exactly
# these names, so verify before "fixing" the spelling.
data_configs = dict(
    data_dir="assets/data/s_adult.csv",
    data_name="adult",
    batch_size=256,
    sample_frac=0.1,
    continous_cols=["age", "hours_per_week"],
    discret_cols=[
        "workclass",
        "education",
        "marital_status",
        "occupation",
        "race",
        "gender",
    ],
)
# Model configuration for PredictiveTrainingModule.
m_configs = dict(
    sizes=[50, 10, 50],
    dropout_rate=0.3,
    lr=0.003,
)
# Training configuration validated into TrainingConfigs by train_model.
t_configs = dict(
    n_epochs=10,
    monitor_metrics="val/val_loss",
    seed=42,
    batch_size=256,
)

In [None]:
from relax.module import PredictiveTrainingModule

# Smoke-test the full pipeline: build the modules from the configs above and train.
params, opt_state = train_model(
    training_module=PredictiveTrainingModule(m_configs),
    data_module=TabularDataModule(data_configs),
    t_configs=t_configs,
)