# Training

> Functions for training models

In [None]:
# default_exp train

In [None]:
# hide
%load_ext autoreload
%autoreload 2
from ipynb_path import *

In [None]:
# export
from cfnet.import_essentials import *
from cfnet.datasets import TabularDataModule
from cfnet.training_module import BaseTrainingModule
from cfnet.logger import TensorboardLogger
from cfnet.utils import validate_configs
from cfnet._ckpt_manager import CheckpointManager

In [None]:
# exporti
class TrainingConfigs(BaseParser):
    """Settings consumed by `train_model` and `train_model_with_states`."""

    # Required settings (no defaults).
    n_epochs: int                 # number of passes of the epoch loop
    monitor_metrics: str          # metric name; "_epoch" is appended for checkpointing

    # Optional settings.
    seed: int = 42                # seed handed to `hk.PRNGSequence`
    log_dir: str = "log"          # passed to the logger as `log_dir`
    logger_name: str = "debug"    # passed to the logger as `name`
    log_on_step: bool = False     # passed to the logger as `on_step`
    max_n_checkpoints: int = 3    # passed to the checkpoint manager

    @property
    def PRNGSequence(self):
        """Build a new `hk.PRNGSequence` from `seed` on every access."""
        rng_sequence = hk.PRNGSequence(self.seed)
        return rng_sequence

In [None]:
# export
def train_model_with_states(
    training_module: BaseTrainingModule,
    params: hk.Params,
    opt_state: optax.OptState,
    data_module: TabularDataModule,
    t_configs: Union[Dict[str, Any], TrainingConfigs]
) -> Tuple[hk.Params, optax.OptState]:
    """Run the epoch loop starting from the given parameters and optimizer state.

    Args:
        training_module: module providing `training_step`, `validation_step`,
            `init_logger`, `hparams` and (after `init_logger`) `logger`.
        params: parameters handed to the first `training_step`.
        opt_state: optimizer state handed to the first `training_step`.
        data_module: source of the train and validation dataloaders.
        t_configs: a `TrainingConfigs` instance or a dict accepted by
            `validate_configs(t_configs, TrainingConfigs)`.

    Returns:
        The `(params, opt_state)` pair produced by the last `training_step`
        (or the inputs unchanged if no training step ran).

    The module's logger is closed even when training raises or is
    interrupted (e.g. `KeyboardInterrupt`); the exception still propagates.
    """
    t_configs = validate_configs(t_configs, TrainingConfigs)
    keys = t_configs.PRNGSequence
    # define logger
    logger = TensorboardLogger(
        log_dir=t_configs.log_dir, name=t_configs.logger_name, on_step=t_configs.log_on_step
    )
    logger.save_hyperparams(t_configs.dict())
    if training_module.hparams:
        logger.save_hyperparams(training_module.hparams)

    training_module.init_logger(logger)
    # Everything after `init_logger` runs under try/finally so the logger is
    # always closed, including on errors and user interruption.
    try:
        # define checkpoint manager; it tracks the epoch-level variant of the
        # monitored metric (hence the "_epoch" suffix)
        ckpt_manager = CheckpointManager(
            log_dir=Path(training_module.logger.log_dir) / 'checkpoints',
            monitor_metrics=f"{t_configs.monitor_metrics}_epoch",
            max_n_checkpoints=t_configs.max_n_checkpoints
        )
        # dataloaders
        train_loader = data_module.train_dataloader()
        val_loader = data_module.val_dataloader()

        # start training
        for epoch in range(t_configs.n_epochs):
            training_module.logger.on_epoch_started()
            # training; only the final epoch's progress bar is left on screen
            with tqdm(train_loader, unit='batch', leave=epoch==t_configs.n_epochs-1) as t_loader:
                t_loader.set_description(f"Epoch {epoch}")
                for batch in t_loader:
                    x, y = map(device_put, tuple(batch))
                    params, opt_state = training_module.training_step(
                        params, opt_state, next(keys), (x, y))
                    # show the most recent logged values next to the progress bar
                    logs = training_module.logger.get_last_logs()
                    t_loader.set_postfix(**logs)

            # validation; the step's return value is not used here
            # (presumably metrics are recorded via the module's logger — confirm)
            for batch in val_loader:
                x, y = map(device_put, tuple(batch))
                training_module.validation_step(params, next(keys), (x, y))
            epoch_logs = training_module.logger.on_epoch_finished()
            ckpt_manager.update_checkpoints(params, opt_state, epoch_logs, epoch)
    finally:
        training_module.logger.close()

    return params, opt_state

In [None]:
# export
def train_model(
    training_module: BaseTrainingModule,
    data_module: TabularDataModule,
    t_configs: Union[Dict[str, Any], TrainingConfigs]
) -> Tuple[hk.Params, optax.OptState]:
    """Initialize the network and optimizer, then train from scratch.

    The configs are validated, one key is drawn from the config's
    `PRNGSequence` for `training_module.init_net_opt`, and the resulting
    states are handed to `train_model_with_states`, whose result is returned.

    NOTE(review): `train_model_with_states` reads `PRNGSequence` again, which
    builds a new sequence from the same seed, so the initialization key
    appears to be reused as the first training key — confirm this is intended.
    """
    configs = validate_configs(t_configs, TrainingConfigs)
    init_key = next(configs.PRNGSequence)
    init_params, init_opt_state = training_module.init_net_opt(data_module, init_key)
    trained = train_model_with_states(
        training_module=training_module,
        params=init_params,
        opt_state=init_opt_state,
        data_module=data_module,
        t_configs=configs,
    )
    return trained

## Test

In [None]:
from cfnet.datasets import TabularDataModule
from cfnet.training_module import PredictiveTrainingModule

# Data settings: a 10% sample of the adult CSV, in batches of 256.
data_configs = dict(
    data_dir="assets/data/s_adult.csv",
    data_name="adult",
    batch_size=256,
    sample_frac=0.1,
    continous_cols=["age", "hours_per_week"],
    discret_cols=[
        "workclass", "education", "marital_status",
        "occupation", "race", "gender",
    ],
)

# Model settings passed to the training module.
m_configs = dict(sizes=[50, 10, 50], dropout_rate=0.3, lr=0.003)

# Training settings: 10 epochs, monitoring the validation loss.
t_configs = dict(n_epochs=10, monitor_metrics="val/val_loss")

In [None]:
from cfnet.training_module import PredictiveTrainingModule

# Smoke test: train a predictive model on the sampled adult data.
params, opt_state = train_model(
    training_module=PredictiveTrainingModule(m_configs),
    data_module=TabularDataModule(data_configs),
    t_configs=t_configs,
)

[autoreload of cfnet.nets failed: Traceback (most recent call last):
  File "/home/birk/miniconda3/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 245, in check
    superreload(m, reload, self.old_objects)
  File "/home/birk/miniconda3/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 410, in superreload
    update_generic(old_obj, new_obj)
  File "/home/birk/miniconda3/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 347, in update_generic
    update(a, b)
  File "/home/birk/miniconda3/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 317, in update_class
    update_instances(old, new)
  File "/home/birk/miniconda3/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 280, in update_instances
    ref.__class__ = new
  File "pydantic/main.py", line 357, in pydantic.main.BaseModel.__setattr__
ValueError: "PredictiveModelConfigs" object has no field "__class__"
]
[autoreload of cfnet.training_module 

KeyboardInterrupt: 