# Training

> Functions for training models

In [None]:
#| default_exp legacy.trainer

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *

In [None]:
#| export
from __future__ import annotations
from relax.legacy.import_essentials import *
from relax.legacy.module import BaseTrainingModule, PredictiveTrainingModule
from relax.legacy.logger import Logger
from relax.legacy.utils import validate_configs, load_json
from relax.legacy.ckpt_manager import CheckpointManager, load_checkpoint
from relax.data_module import DataModule
from urllib.request import urlretrieve
from keras.src.trainers.epoch_iterator import EpochIterator


In [None]:
#| export
class TrainingConfigs(BaseParser):
    """Configurator of `train_model`."""

    # Training schedule: neither field is given a default value.
    n_epochs: int = Field(description="Number of epochs.")
    batch_size: int = Field(description="Batch size.")

    # Metric name used for checkpoint selection; `None` disables monitoring.
    monitor_metrics: Optional[str] = Field(
        default=None,
        description="Monitor metrics used to evaluate the training result after each epoch.",
    )

    # Seed consumed by the `PRNGSequence` property below.
    seed: int = Field(default=42, description="Seed for generating random number.")

    # Logging options.
    log_dir: str = Field(
        default="log",
        description="The name for the directory that holds logged data during training.",
    )
    logger_name: str = Field(
        default="debug",
        description="The name for the directory that holds logged data during training under log directory.",
    )
    log_on_step: bool = Field(
        default=False,
        description="Log the evaluate metrics at the current step.",
    )

    # Checkpointing options.
    max_n_checkpoints: int = Field(
        default=3,
        description="Maximum number of checkpoints stored.",
    )

    @property
    def PRNGSequence(self):
        # A new sequence is built from `self.seed` on every access, so two
        # accesses yield two independent sequences that start from the same seed.
        return hk.PRNGSequence(self.seed)


In [None]:
#| export
def train_model_with_states(
    training_module: BaseTrainingModule,
    params: hk.Params,
    opt_state: optax.OptState,
    data_module: DataModule,
    t_configs: Dict[str, Any] | TrainingConfigs,
) -> Tuple[hk.Params, optax.OptState]:
    """Run the training loop starting from the given `params` and `opt_state`.

    Every epoch iterates the 'train' split through a jitted training step,
    then calls `validation_step` on each batch of the 'test' split, and
    finally passes the epoch logs to the checkpoint manager. The final
    `(params, opt_state)` pair is returned.
    """
    t_configs = validate_configs(t_configs, TrainingConfigs)
    rng_keys = t_configs.PRNGSequence

    # Logger: store the training configs and, when available, the module's hparams.
    logger = Logger(
        log_dir=t_configs.log_dir,
        name=t_configs.logger_name,
        on_step=t_configs.log_on_step,
    )
    logger.save_hyperparams(t_configs.dict())
    module_hparams = getattr(training_module, "hparams", None)
    if module_hparams is not None:
        logger.save_hyperparams(module_hparams)
    training_module.init_logger(logger)

    # Checkpoint manager: the monitored metric is tracked under its "_epoch" name.
    monitored = t_configs.monitor_metrics
    ckpt_manager = CheckpointManager(
        log_dir=Path(training_module.logger.log_dir) / "checkpoints",
        monitor_metrics=None if monitored is None else f"{monitored}_epoch",
        max_n_checkpoints=t_configs.max_n_checkpoints,
    )

    # Batch iterators: shuffled for training, fixed order for validation.
    batch_size = t_configs.batch_size
    train_batches = EpochIterator(*data_module['train'], batch_size=batch_size, shuffle=True)
    val_batches = EpochIterator(*data_module['test'], batch_size=batch_size, shuffle=False)

    @jax.jit
    def jitted_train_step(params, opt_state, key, batch):
        return training_module.training_step(params, opt_state, key, batch)

    final_epoch = t_configs.n_epochs - 1
    for epoch in range(t_configs.n_epochs):
        training_module.logger.on_epoch_started()

        # Training pass; only the last epoch's progress bar is left on screen.
        progress = tqdm(
            train_batches.enumerate_epoch('np'),
            unit="batch",
            leave=(epoch == final_epoch),
            total=train_batches.num_batches,
        )
        with progress as bar:
            bar.set_description(f"Epoch {epoch}")
            for _, batch in bar:
                inputs, targets = batch[0]
                step_logs, (params, opt_state) = jitted_train_step(
                    params, opt_state, next(rng_keys), (inputs, targets)
                )
                # TODO: tqdm becomes the bottleneck
                bar.set_postfix(**step_logs)

        # Validation pass; the value returned by `validation_step` is not used here.
        for _, batch in val_batches.enumerate_epoch('np'):
            inputs, targets = batch[0]
            training_module.validation_step(params, next(rng_keys), (inputs, targets))

        epoch_logs = training_module.logger.on_epoch_finished()
        ckpt_manager.update_checkpoints(params, opt_state, epoch_logs, epoch)

    training_module.logger.close()
    return params, opt_state


In [None]:
#| export
def train_model(
    training_module: BaseTrainingModule, # Training module
    data_module: DataModule, # Data module
    batch_size=128, # Batch size
    epochs=1, # Number of epochs
    **fit_kwargs # Extra keyword arguments forwarded to `TrainingConfigs`
) -> Tuple[hk.Params, optax.OptState]:
    """Initialize the network and optimizer state, then train the model."""
    configs = TrainingConfigs(n_epochs=epochs, batch_size=batch_size, **fit_kwargs)

    # The first key of a sequence seeded from the configs initializes the
    # parameters and the optimizer state.
    init_key = next(configs.PRNGSequence)
    params, opt_state = training_module.init_net_opt(data_module, init_key)

    return train_model_with_states(
        training_module, params, opt_state, data_module, configs
    )

## Examples

A simple example of training a predictive model.

In [None]:
from relax.legacy.module import PredictiveTrainingModule, PredictiveModelConfigs
from relax.data_module import load_data

In [None]:
# Load the 'adult' dataset.
datamodule = load_data('adult')

# Train a predictive module on it; `batch_size` and `epochs` are left at the
# `train_model` defaults (128 and 1).
params, opt_state = train_model(
    PredictiveTrainingModule({'sizes': [64, 32, 16], 'lr': 0.003}), 
    datamodule,
)

Epoch 0: 100%|██████████| 191/191 [00:01<00:00, 102.62batch/s, train/train_loss=0.06518286] 


In [None]:
from relax.ml_model import MLModule

In [None]:
# Train an `MLModule` on the same data module for one epoch with batch size 128.
model = MLModule()
model.train(datamodule, batch_size=128, epochs=1)

[1m191/191[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 5ms/step - accuracy: 0.7148 - loss: 0.5726


<relax.ml_model.MLModule>