# Checkpoint Manager

> Manage the model and optimizer checkpoints

In [None]:
#| default_exp _ckpt_manager

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *

In [None]:
#| export
from relax.import_essentials import *
from collections import OrderedDict

In [None]:
#| export
# https://github.com/deepmind/dm-haiku/issues/18#issuecomment-981814403
def save_checkpoint(state, ckpt_dir: Path):
    """Write the pytree `state` into `ckpt_dir` as two files.

    * ``params.npy`` holds every leaf of `state`, appended one after another
      with `np.save` (pickling disabled), in flattening order.
    * ``tree.pkl`` holds a pickled copy of `state` whose leaves are all
      replaced by ``0``; it only records the tree structure.

    `ckpt_dir` must already exist.
    """
    arrays_path = os.path.join(ckpt_dir, "params.npy")
    skeleton_path = os.path.join(ckpt_dir, "tree.pkl")

    with open(arrays_path, "wb") as arrays_file:
        leaves, treedef = jax.tree_util.tree_flatten(state)
        for leaf in leaves:
            np.save(arrays_file, leaf, allow_pickle=False)

    # Same structure as `state`, with a 0 placeholder in place of each leaf.
    skeleton = jax.tree_util.tree_unflatten(treedef, [0] * len(leaves))
    with open(skeleton_path, "wb") as skeleton_file:
        pickle.dump(skeleton, skeleton_file)


def load_checkpoint(ckpt_dir: Path):
    """Rebuild a pytree from the ``tree.pkl`` / ``params.npy`` pair in `ckpt_dir`.

    The structure is read from ``tree.pkl``; one array per leaf is then read
    sequentially from ``params.npy`` and placed back into that structure.
    The returned leaves are whatever `np.load` yields.

    NOTE: ``tree.pkl`` is unpickled, so only load checkpoints from a trusted
    source.
    """
    skeleton_path = os.path.join(ckpt_dir, "tree.pkl")
    arrays_path = os.path.join(ckpt_dir, "params.npy")

    with open(skeleton_path, "rb") as skeleton_file:
        skeleton = pickle.load(skeleton_file)
    placeholders, treedef = jax.tree_util.tree_flatten(skeleton)

    # The arrays were appended back to back, so read exactly one per leaf.
    arrays = []
    with open(arrays_path, "rb") as arrays_file:
        for _ in range(len(placeholders)):
            arrays.append(np.load(arrays_file))

    return jax.tree_util.tree_unflatten(treedef, arrays)


In [None]:
#| export
class CheckpointManager:
    """Keep the `max_n_checkpoints` checkpoints with the lowest monitored metric.

    Each checkpoint lives in ``log_dir / <ckpt_name>`` with a ``model`` and an
    ``opt`` sub-directory. `checkpoints` maps metric value -> checkpoint name
    and is kept sorted by ascending metric, so the last item is the worst
    checkpoint currently stored. `n_checkpoints` mirrors ``len(checkpoints)``.
    """

    def __init__(
        self,
        log_dir: Union[Path, str],
        monitor_metrics: Optional[str],
        max_n_checkpoints: int = 3,
    ):
        """Configure the manager.

        Args:
            log_dir: Directory under which checkpoint directories are created.
            monitor_metrics: Key in the epoch logs used to rank checkpoints
                (lower is better). If `None`, nothing is ever stored.
            max_n_checkpoints: Maximum number of checkpoints kept on disk.
                A value of 0 or less disables storing.
        """
        self.log_dir = Path(log_dir)
        self.monitor_metrics = monitor_metrics
        self.max_n_checkpoints = max_n_checkpoints
        # metric value -> checkpoint name, sorted by ascending metric.
        self.checkpoints = OrderedDict()
        self.n_checkpoints = 0
        if self.monitor_metrics is None:
            warnings.warn(
                "`monitor_metrics` is not specified in `CheckpointManager`. No checkpoints will be stored."
            )

    def update_checkpoints(
        self,
        params: hk.Params,
        opt_state: optax.OptState,
        epoch_logs: Dict[str, float],
        epochs: int,
        steps: Optional[int] = None,
    ):
        """Store the current state if its metric ranks among the best ones.

        Args:
            params: Model parameters to save.
            opt_state: Optimizer state to save.
            epoch_logs: Logged values; must contain `monitor_metrics`.
            epochs: Epoch number, used in the checkpoint name.
            steps: Optional step number; when truthy it is appended to the
                checkpoint name (``None`` and ``0`` give an epoch-only name).

        Raises:
            ValueError: If `monitor_metrics` is not a key of `epoch_logs`.
        """
        if self.monitor_metrics is None:
            return
        if self.monitor_metrics not in epoch_logs:
            raise ValueError(
                "The monitor_metrics ({}) is not appropriately configured.".format(
                    self.monitor_metrics
                )
            )
        metric = float(epoch_logs[self.monitor_metrics])
        if np.isnan(metric):
            # NaN never compares equal or smaller, so it would corrupt the
            # sorted order that eviction relies on.
            warnings.warn(
                "The monitored metric ({}) is NaN. The checkpoint is skipped.".format(
                    self.monitor_metrics
                )
            )
            return
        if self.max_n_checkpoints <= 0:
            # Nothing may be kept; avoids evicting from an empty collection.
            return

        if steps:
            ckpt_name = f"epoch={epochs}_step={steps}"
        else:
            ckpt_name = f"epoch={epochs}"

        if metric in self.checkpoints:
            # Tie with a stored checkpoint: the metric is the dict key, so the
            # newer checkpoint replaces the older one and its directory is
            # removed instead of being left behind untracked.
            stale_name = self.checkpoints[metric]
            self.save_net_opt(params, opt_state, ckpt_name)
            self.checkpoints[metric] = ckpt_name
            if stale_name != ckpt_name:
                self.delete_net_opt(stale_name)
        elif len(self.checkpoints) < self.max_n_checkpoints:
            self.save_net_opt(params, opt_state, ckpt_name)
            self.checkpoints[metric] = ckpt_name
        else:
            worst_metric = max(self.checkpoints)
            if metric < worst_metric:
                worst_name = self.checkpoints.pop(worst_metric)
                # Save first so a failed save never reduces what is on disk.
                self.save_net_opt(params, opt_state, ckpt_name)
                self.checkpoints[metric] = ckpt_name
                if worst_name != ckpt_name:
                    self.delete_net_opt(worst_name)

        self.checkpoints = OrderedDict(
            sorted(self.checkpoints.items(), key=lambda x: x[0])
        )
        # Derive the counter from the mapping so the two cannot drift apart.
        self.n_checkpoints = len(self.checkpoints)

    def save_net_opt(self, params, opt_state, ckpt_name: str):
        """Save `params` and `opt_state` under ``log_dir / ckpt_name``."""
        ckpt_dir = self.log_dir / f"{ckpt_name}"
        ckpt_dir.mkdir(parents=True, exist_ok=True)
        model_ckpt_dir = ckpt_dir / "model"
        opt_ckpt_dir = ckpt_dir / "opt"
        # create dirs for storing states of model and optimizer
        model_ckpt_dir.mkdir(parents=True, exist_ok=True)
        opt_ckpt_dir.mkdir(parents=True, exist_ok=True)
        # save model and optimizer states
        save_checkpoint(params, model_ckpt_dir)
        save_checkpoint(opt_state, opt_ckpt_dir)

    def delete_net_opt(self, ckpt_name: str):
        """Remove the checkpoint directory ``log_dir / ckpt_name`` recursively."""
        ckpt_dir = self.log_dir / f"{ckpt_name}"
        shutil.rmtree(ckpt_dir)

#### Example


In [None]:
from relax.data import load_data
from relax.module import PredictiveTrainingModule


In [None]:
# Keep at most three checkpoints, ranked by the monitored training loss.
key = hk.PRNGSequence(42)
ckpt_manager = CheckpointManager(
    log_dir='log',
    monitor_metrics='train/train_loss_1',
    max_n_checkpoints=3,
)
dm = load_data('adult')
module = PredictiveTrainingModule({'lr': 0.01, 'sizes': [50, 10, 50]})
params, opt_state = module.init_net_opt(dm, next(key))

# Report one loss value per epoch; epochs are numbered from 1.
epoch_losses = [0.1, 0.2, 0.15, 0.05, 0.14]
for epoch, loss in enumerate(epoch_losses, start=1):
    logs = {'train/train_loss_1': loss}
    ckpt_manager.update_checkpoints(params, opt_state, logs, epochs=epoch)

assert ckpt_manager.n_checkpoints == len(ckpt_manager.checkpoints)
# The last item is the entry with the largest metric among those kept.
worst_metric, _ = ckpt_manager.checkpoints.popitem(last=True)
assert worst_metric == 0.14

# Clean up every directory the example may have created.
for epoch in range(1, len(epoch_losses) + 1):
    shutil.rmtree(Path(f'log/epoch={epoch}'), ignore_errors=True)

