# Training

> Functions for training models

In [None]:
#| default_exp trainer

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
from relax.utils import show_doc
show_doc_parser = show_doc

In [None]:
#| export
from __future__ import annotations
from relax.import_essentials import *
from relax.data import TabularDataModule, load_data
from relax.data.module import DEFAULT_DATA_CONFIGS
from relax.module import BaseTrainingModule, PredictiveTrainingModule
from relax.logger import TensorboardLogger
from relax.utils import validate_configs, load_json
from relax._ckpt_manager import CheckpointManager, load_checkpoint
from urllib.request import urlretrieve

In [None]:
#| export
class TrainingConfigs(BaseParser):
    """Configurator of `train_model`."""

    # Required settings (no default value).
    n_epochs: int = Field(description="Number of epochs.")
    batch_size: int = Field(description="Batch size.")

    # Optional settings.
    monitor_metrics: Optional[str] = Field(
        default=None,
        description="Monitor metrics used to evaluate the training result after each epoch.",
    )
    seed: int = Field(
        default=42,
        description="Seed for generating random number.",
    )
    log_dir: str = Field(
        default="log",
        description="The name for the directory that holds logged data during training.",
    )
    logger_name: str = Field(
        default="debug",
        description="The name for the directory that holds logged data during training under log directory.",
    )
    log_on_step: bool = Field(
        default=False,
        description="Log the evaluate metrics at the current step.",
    )
    max_n_checkpoints: int = Field(
        default=3,
        description="Maximum number of checkpoints stored.",
    )

    @property
    def PRNGSequence(self):
        """Build a `hk.PRNGSequence` seeded with `self.seed`.

        A new sequence object is constructed on every access.
        """
        rng_sequence = hk.PRNGSequence(self.seed)
        return rng_sequence


In [None]:
#| export
def train_model_with_states(
    training_module: BaseTrainingModule,
    params: hk.Params,
    opt_state: optax.OptState,
    data_module: TabularDataModule,
    t_configs: Dict[str, Any] | TrainingConfigs,
) -> Tuple[hk.Params, optax.OptState]:
    """Train models with `params` and `opt_state`.

    Runs `t_configs.n_epochs` epochs. Each epoch iterates over the training
    dataloader calling `training_module.training_step`, then over the
    validation dataloader calling `training_module.validation_step`, and
    finally hands the epoch logs to the checkpoint manager.

    Returns the `(params, opt_state)` pair produced by the last training step
    (or the inputs unchanged when no step was run).

    The logger attached to `training_module` is closed even when training
    raises, so a failed run does not leave the logger open.
    """

    t_configs = validate_configs(t_configs, TrainingConfigs)
    keys = t_configs.PRNGSequence
    # define logger
    logger = TensorboardLogger(
        log_dir=t_configs.log_dir,
        name=t_configs.logger_name,
        on_step=t_configs.log_on_step,
    )
    logger.save_hyperparams(t_configs.dict())
    if getattr(training_module, "hparams", None) is not None:
        logger.save_hyperparams(training_module.hparams)

    training_module.init_logger(logger)

    try:
        # define checkpoint manager; the monitored metric is looked up under
        # its per-epoch name, hence the `_epoch` suffix.
        if t_configs.monitor_metrics is None:
            monitor_metrics = None
        else:
            monitor_metrics = f"{t_configs.monitor_metrics}_epoch"

        ckpt_manager = CheckpointManager(
            log_dir=Path(training_module.logger.log_dir) / "checkpoints",
            monitor_metrics=monitor_metrics,
            max_n_checkpoints=t_configs.max_n_checkpoints,
        )
        # dataloaders
        train_loader = data_module.train_dataloader(t_configs.batch_size)
        val_loader = data_module.val_dataloader(t_configs.batch_size)

        # start training
        last_epoch = t_configs.n_epochs - 1
        for epoch in range(t_configs.n_epochs):
            training_module.logger.on_epoch_started()
            # training; only the final epoch's progress bar is left on screen
            with tqdm(
                train_loader, unit="batch", leave=epoch == last_epoch
            ) as t_loader:
                t_loader.set_description(f"Epoch {epoch}")
                for batch in t_loader:
                    x, y = map(device_put, tuple(batch))
                    params, opt_state = training_module.training_step(
                        params, opt_state, next(keys), (x, y)
                    )
                    # show the most recently logged values on the progress bar
                    logs = training_module.logger.get_last_logs()
                    t_loader.set_postfix(**logs)

            # validation; the return value of `validation_step` is not used here
            for batch in val_loader:
                x, y = map(device_put, tuple(batch))
                training_module.validation_step(params, next(keys), (x, y))
            epoch_logs = training_module.logger.on_epoch_finished()
            ckpt_manager.update_checkpoints(params, opt_state, epoch_logs, epoch)
    finally:
        # always release the logger, including when a step raised
        training_module.logger.close()

    return params, opt_state


In [None]:
#| export
def train_model(
    training_module: BaseTrainingModule, # Training module
    data_module: TabularDataModule, # Data module
    t_configs: Dict[str, Any] | TrainingConfigs, # Training configurator
) -> Tuple[hk.Params, optax.OptState]:
    """Initialize the network and optimizer states, then train the model."""

    # Normalize the configs first so the seed is available for initialization.
    t_configs = validate_configs(t_configs, TrainingConfigs)
    rng_keys = t_configs.PRNGSequence
    init_key = next(rng_keys)
    init_states = training_module.init_net_opt(data_module, init_key)
    params, opt_state = init_states

    # Delegate the actual training loop; arguments follow the callee's
    # parameter order (training_module, params, opt_state, data_module, t_configs).
    return train_model_with_states(
        training_module, params, opt_state, data_module, t_configs
    )

In [None]:
#| export module
def load_pred_model(
    data_name: str # The name of data
    ) -> Tuple[hk.Params, PredictiveTrainingModule]:
    """Load the pretrained predictive model of a built-in dataset.

    Returns a `(params, module)` pair. Raises `ValueError` when `data_name`
    is not a key of `DEFAULT_DATA_CONFIGS`.
    """

    # Reject unknown dataset names before touching the network or the disk.
    is_known = data_name in DEFAULT_DATA_CONFIGS.keys()
    if not is_known:
        raise ValueError(f'`data_name` must be one of {DEFAULT_DATA_CONFIGS.keys()}, '
            f'but got data_name={data_name}.')

    # Make sure the model files are present locally.
    download_model(data_name)

    # The network sizes and learning rate come from the dataset's configs file,
    # which is read from `./cf_data/<data_name>/configs.json`.
    dataset_root = Path(os.getcwd()) / "cf_data" / data_name
    configs_file = dataset_root / "configs.json"
    mlp_configs = load_json(configs_file)['mlp_configs']
    module_configs = {'sizes': mlp_configs["sizes"], 'lr': mlp_configs["lr"]}

    module = PredictiveTrainingModule(module_configs)
    param = load_checkpoint(dataset_root / "model")
    return param, module


def download_model(
    data_name: str # The name of data
    ):
    """High-level util function for downloading a trained model.

    Fetches `params.npy` and `tree.pkl` for `data_name` into
    `./cf_data/<data_name>/model`. A file that is already present is not
    downloaded again. Raises `ValueError` when `data_name` is not a key of
    `DEFAULT_DATA_CONFIGS`.
    """

    # validate data name
    if data_name not in DEFAULT_DATA_CONFIGS:
        raise ValueError(f'`data_name` must be one of {DEFAULT_DATA_CONFIGS.keys()}, '
            f'but got data_name={data_name}.')

    # create the target directory; `exist_ok=True` creates missing parents and
    # avoids the race between an existence check and the creation.
    model_path = Path(os.getcwd()) / "cf_data" / data_name / "model"
    os.makedirs(model_path, exist_ok=True)

    base_url = f"https://github.com/BirkhoffG/ReLax/raw/master/assets/{data_name}/model"

    # download trained model
    for file_name in ("params.npy", "tree.pkl"):
        target = model_path / file_name
        if target.is_file():
            continue
        # Download into a temporary sibling and move it into place only once
        # the transfer succeeded. Otherwise an interrupted download would leave
        # a truncated file that the `is_file` check above treats as complete.
        partial = model_path / f"{file_name}.part"
        try:
            urlretrieve(f"{base_url}/{file_name}", str(partial))
            os.replace(partial, target)
        finally:
            if partial.exists():
                partial.unlink()


# Pretrain model

In [None]:
#| slow
import shutil

# Train one predictive model per built-in dataset and publish the resulting
# checkpoint under both `assets/` and `cf_data/`.
for data_name in DEFAULT_DATA_CONFIGS:
    datamodule = load_data(data_name = data_name)

    # Fetch the sizes, lr and batch size from the configs file (parsed once).
    configs_path = Path(os.getcwd()) / "cf_data" / data_name / "configs.json"
    configs = load_json(configs_path)
    mlp_configs = configs['mlp_configs']
    sizes = mlp_configs["sizes"]
    lr = mlp_configs["lr"]
    batch_size = configs["data_configs"]['batch_size']

    params, opt_state = train_model(
        PredictiveTrainingModule({'sizes': sizes, 'lr': lr}),
        datamodule, t_configs={
            'n_epochs': 10, 'batch_size': batch_size, 'monitor_metrics': 'val/val_loss',
            'max_n_checkpoints': 1, 'logger_name': data_name
        }
    )

    # get the most recent version (by modification time) and the epoch stored in it
    version_dir = f"log/{data_name}/"
    latest_version = max(
        (os.path.join(version_dir, v) for v in os.listdir(version_dir) if v.startswith("version_")),
        key=os.path.getmtime,
    )
    checkpoints_dir = os.path.join(latest_version, "checkpoints")
    # `max_n_checkpoints` is 1 above, so the first `epoch*` entry is taken as the kept one
    epoch = [d for d in os.listdir(checkpoints_dir) if d.startswith("epoch")][0]
    model_dir = os.path.join(checkpoints_dir, epoch, "model")

    # update model in the assets, and (as a test) save it under cf_data as well
    for dest_root in ("assets", "cf_data"):
        dest_dir = f"{dest_root}/{data_name}/model"
        shutil.rmtree(dest_dir, ignore_errors=True)
        shutil.copytree(model_dir, dest_dir)

Epoch 9: 100%|██████████| 96/96 [00:00<00:00, 348.92batch/s, train/train_loss_1=0.0706]
Epoch 9: 100%|██████████| 31/31 [00:00<00:00, 367.97batch/s, train/train_loss_1=0.117] 
Epoch 9: 100%|██████████| 96/96 [00:00<00:00, 317.62batch/s, train/train_loss_1=0.0331]
Epoch 9: 100%|██████████| 88/88 [00:00<00:00, 337.16batch/s, train/train_loss_1=0.087] 
Epoch 9: 100%|██████████| 14/14 [00:00<00:00, 648.68batch/s, train/train_loss_1=0.0509]
Epoch 9: 100%|██████████| 16/16 [00:00<00:00, 691.80batch/s, train/train_loss_1=0.129]
Epoch 9: 100%|██████████| 11/11 [00:00<00:00, 554.94batch/s, train/train_loss_1=0.0687]
Epoch 9: 100%|██████████| 12/12 [00:00<00:00, 570.37batch/s, train/train_loss_1=0.057]
Epoch 9: 100%|██████████| 14/14 [00:00<00:00, 340.55batch/s, train/train_loss_1=0.0338]
Epoch 9: 100%|██████████| 8/8 [00:00<00:00, 383.47batch/s, train/train_loss_1=0.0232]
Epoch 9: 100%|██████████| 7/7 [00:00<00:00, 523.45batch/s, train/train_loss_1=0.1] 
Epoch 9: 100%|██████████| 11/11 [00:00<0

# Test

In [None]:
#| slow
from sklearn.metrics import accuracy_score
# Collect the test accuracy of every pretrained model, one row per dataset.
log = dict(name=[], accuracy=[])
for data_name in DEFAULT_DATA_CONFIGS:
    datamodule = load_data(data_name=data_name)
    params, module = load_pred_model(data_name)
    x, y_true = datamodule.test_dataset[:]
    probs = module.pred_fn(x=x, params=params, rng_key=random.PRNGKey(0))
    # the model is expected to emit one value per input row
    assert probs.shape == (x.shape[0], 1)

    # calculate accuracy: threshold at 0.5 to obtain 0/1 labels
    y_pred = (probs > 0.5).astype(int)
    accuracy = accuracy_score(y_true, y_pred)

    log["name"].append(data_name)
    log["accuracy"].append(accuracy)



# Random Forest 

In [None]:
#| slow
from sklearn.ensemble import RandomForestClassifier

# Random-forest baseline: fit one classifier per dataset and record its test
# accuracy next to the accuracy of the pretrained model.
log["rfc accuracy"] = []
for data_name in DEFAULT_DATA_CONFIGS.keys():
    rfc = RandomForestClassifier(random_state=0)
    datamodule = load_data(data_name = data_name)
    X_train, y_train = datamodule.train_dataset[:]
    # Flatten the labels before fitting: the recorded output of this cell shows
    # one warning per `rfc.fit(X_train, y_train)` call.
    # NOTE(review): presumably `y_train` is an array of shape (n, 1) and this is
    # scikit-learn's column-vector warning - confirm.
    rfc.fit(X_train, y_train.ravel())
    X_test, y_test = datamodule.test_dataset[:]
    y_pred = rfc.predict(X_test)

    # calculate accuracy
    y_pred = y_pred > 0.5
    y_pred = y_pred.astype(int)
    accuracy = accuracy_score(y_test,y_pred)
    log["rfc accuracy"].append(accuracy)

pd.DataFrame.from_dict(log)

  rfc.fit(X_train, y_train)
  rfc.fit(X_train, y_train)
  rfc.fit(X_train, y_train)
  rfc.fit(X_train, y_train)
  rfc.fit(X_train, y_train)
  rfc.fit(X_train, y_train)
  rfc.fit(X_train, y_train)
  rfc.fit(X_train, y_train)
  rfc.fit(X_train, y_train)
  rfc.fit(X_train, y_train)
  rfc.fit(X_train, y_train)
  rfc.fit(X_train, y_train)
  rfc.fit(X_train, y_train)
  rfc.fit(X_train, y_train)


Unnamed: 0,name,accuracy,rfc accuracy
0,adult,0.8241,0.806166
1,heloc,0.702868,0.719312
2,oulad,0.926739,0.940361
3,credit,0.8132,0.813467
4,cancer,0.909091,0.916084
5,student_performance,0.90184,0.920245
6,titanic,0.816143,0.802691
7,german,0.756,0.756
8,spam,0.93397,0.943527
9,ozone,0.933754,0.949527


## Examples

A simple example of training a predictive model.

In [None]:
from relax.data import TabularDataModule, load_data
from relax.module import PredictiveTrainingModule, PredictiveModelConfigs

In [None]:
# Load the built-in `adult` dataset.
datamodule = load_data('adult')

# Train a predictive model for 10 epochs, monitoring `val/val_loss`.
pred_module = PredictiveTrainingModule({'sizes': [50, 10, 50], 'lr': 0.003})
example_t_configs = dict(n_epochs=10, batch_size=256, monitor_metrics='val/val_loss')
params, opt_state = train_model(pred_module, datamodule, t_configs=example_t_configs)

Epoch 9: 100%|██████████| 96/96 [00:00<00:00, 377.38batch/s, train/train_loss_1=0.0706]
