> [!NOTE]
>
> This notebook was inspired by [https://orobix.github.io/quadra/v1.3.6/tutorials/model_management.html](https://orobix.github.io/quadra/v1.3.6/tutorials/model_management.html)

# Model Manager

In this notebook, we present the [MlflowModelManager](../sheeprl/utils/model_manager.py) and its possible uses.
It includes methods to:
* Register the model
* Retrieve the latest version
* Transition the model to a new stage
* Delete the model

First of all, we need to start the MLflow server with an artifact store. You can find the instructions for running the MLflow server [here](https://mlflow.org/docs/latest/tracking.html#tracking-ui). Open a new terminal and run the following command:
```bash
mlflow ui
```

> [!NOTE]
>
> This is only one option: the server could also run on another machine, in which case you just need to set the `tracking_uri` parameter accordingly.
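As a small illustration of that point, the sketch below resolves the tracking URI to pass as the `metric.logger.tracking_uri` override. The helper `resolve_tracking_uri` is hypothetical and not part of SheepRL; the precedence it implements (explicit value, then the `MLFLOW_TRACKING_URI` environment variable that MLflow itself reads, then `http://localhost:5000`, where `mlflow ui` listens by default) is an assumption made for this example.

```python
import os
from typing import Optional

# `mlflow ui` serves on port 5000 of the local machine by default
DEFAULT_TRACKING_URI = "http://localhost:5000"


def resolve_tracking_uri(explicit: Optional[str] = None) -> str:
    """Hypothetical helper: explicit URI > MLFLOW_TRACKING_URI env var > local default (assumed precedence)."""
    if explicit:
        return explicit
    return os.environ.get("MLFLOW_TRACKING_URI", DEFAULT_TRACKING_URI)


# e.g. a server running on another machine of the local network
override = f"metric.logger.tracking_uri={resolve_tracking_uri('http://192.168.1.10:5000')}"
print(override)
```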

### Running the Experiment and Registering the Model
Second, we launch an experiment: we retrieve the configs and execute the `run_algorithm` function. We train a PPO agent in the CartPole-v1 environment for a few steps (we do not aim for the best performance; we only want to show how SheepRL handles model management for reinforcement learning).

In [None]:
import hydra
from omegaconf import OmegaConf
from sheeprl.utils.utils import dotdict
from sheeprl.cli import check_configs, run_algorithm

# To retrieve the configs, we can simulate the cli command
# `python sheeprl.py exp=ppo algo.total_steps=1024 model_manager.disabled=False logger@metric.logger=mlflow checkpoint.every=1024 exp_name=mlflow_example metric.logger.tracking_uri="http://localhost:5000"`
with hydra.initialize(version_base="1.3", config_path="../sheeprl/configs"):
    cfg = hydra.compose(
        config_name="config.yaml",
        overrides=[
            "exp=ppo",
            "algo.total_steps=1024",
            "model_manager.disabled=False",
            "logger@metric.logger=mlflow",
            "checkpoint.every=1024",
            "exp_name=mlflow_example",
            "metric.logger.tracking_uri=http://localhost:5000",
        ],
    )
    cfg = dotdict(OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True))
check_configs(cfg)
run_algorithm(cfg)
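The override list passed to `hydra.compose` mirrors the CLI arguments one-to-one. A minimal sketch of that correspondence, assuming plain `key=value` overrides: the helpers `to_overrides` and `to_cli` are hypothetical and only build strings; they do not call Hydra or SheepRL.

```python
import shlex
from typing import Any, Dict, List


def to_overrides(params: Dict[str, Any]) -> List[str]:
    """Hypothetical helper: {"algo.total_steps": 1024} -> ["algo.total_steps=1024"]."""
    return [f"{key}={value}" for key, value in params.items()]


def to_cli(params: Dict[str, Any], script: str = "sheeprl.py") -> str:
    """Hypothetical helper: the equivalent shell command, quoting arguments that need it."""
    return " ".join(["python", script] + [shlex.quote(arg) for arg in to_overrides(params)])


params = {
    "exp": "ppo",
    "algo.total_steps": 1024,
    "model_manager.disabled": False,
    "logger@metric.logger": "mlflow",
    "checkpoint.every": 1024,
    "exp_name": "mlflow_example",
    "metric.logger.tracking_uri": "http://localhost:5000",
}
overrides = to_overrides(params)  # could be passed as `overrides=` to hydra.compose
print(to_cli(params))
```

Keeping the parameters in a single dictionary avoids the CLI comment and the override list drifting apart when one of them is edited.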

### Get Experiment Info

The experiment is logged in MLflow, and we can retrieve it with the following instructions. Moreover, given the experiment, we can retrieve all of its runs with the `mlflow.search_runs()` function.

> [!NOTE]
>
> You can also check this information by opening the MLflow address in a browser, e.g., `http://localhost:5000` if you are running MLflow locally.

In [2]:
import mlflow

mlflow.set_tracking_uri(cfg.metric.logger.tracking_uri)
exp = mlflow.get_experiment_by_name("mlflow_example")
print("Experiment:", exp)
runs = mlflow.search_runs(experiment_ids=[exp.experiment_id])
print(f"Experiment ({exp.experiment_id}) runs:")
runs

Experiment: <Experiment: artifact_location='mlflow-artifacts:/242317125620601262', creation_time=1701949559261, experiment_id='242317125620601262', last_update_time=1701949559261, lifecycle_stage='active', name='mlflow_example', tags={}>
Experiment (242317125620601262) runs:


| | run_id | experiment_id | status | artifact_uri | start_time | end_time | metrics.Loss/entropy_loss | metrics.Test/cumulative_reward | metrics.Info/ent_coef | metrics.Info/learning_rate | ... | params.algo/gae_lambda | params.env/action_repeat | params.env/grayscale | params.metric/aggregator/metrics/Loss/policy_loss/sync_on_compute | params.metric/log_level | tags.mlflow.user | tags.mlflow.source.type | tags.mlflow.runName | tags.mlflow.source.name | tags.mlflow.log-model.history |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1e453cf2114d43f28410803df985598a | 242317125620601262 | FINISHED | mlflow-artifacts:/242317125620601262/1e453cf21... | 2023-12-07 11:45:59.641000+00:00 | 2023-12-07 11:46:10.350000+00:00 | -0.687031 | 48.0 | 0.0 | 0.001 | ... | 0.95 | 1 | False | False | 1 | mmilesi | LOCAL | ppo_CartPole-v1_2023-12-07_12-45-58 | /home/mmilesi/miniconda3/envs/sheeprl/lib/pyth... | [{"run_id": "1e453cf2114d43f28410803df985598a"... |
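The runs table follows MLflow's column naming: run attributes such as `run_id` and `status` are plain columns, while metrics, parameters, and tags are prefixed with `metrics.`, `params.`, and `tags.`. The hypothetical helper below groups a single row by those prefixes; with pandas, a row can be turned into a dictionary with `runs.iloc[0].to_dict()`.

```python
from typing import Any, Dict

PREFIXES = ("metrics", "params", "tags")


def split_run_row(row: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Hypothetical helper: group a flat runs-table row by MLflow's column prefixes."""
    groups: Dict[str, Dict[str, Any]] = {"attributes": {}, **{p: {} for p in PREFIXES}}
    for column, value in row.items():
        prefix, sep, name = column.partition(".")
        if sep and prefix in PREFIXES:
            groups[prefix][name] = value  # e.g. "tags.mlflow.user" -> tags["mlflow.user"]
        else:
            groups["attributes"][column] = value  # e.g. "run_id", "status"
    return groups


# A few columns taken from the runs table above (params are stored as strings by MLflow)
row = {
    "run_id": "1e453cf2114d43f28410803df985598a",
    "status": "FINISHED",
    "metrics.Test/cumulative_reward": 48.0,
    "params.algo/gae_lambda": "0.95",
    "tags.mlflow.user": "mmilesi",
}
groups = split_run_row(row)
print(groups["metrics"])
```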


### Retrieve Model Info
Since we set `model_manager.disabled` to `False`, the PPO agent is registered in MLflow, and we can retrieve its information with the following instructions.

In [3]:
from sheeprl.utils.mlflow import MlflowModelManager
from lightning import Fabric

fabric = Fabric(devices=1, accelerator=cfg.fabric.accelerator, precision=cfg.fabric.precision)
fabric.launch()
model_manager = MlflowModelManager(fabric, cfg.model_manager.tracking_uri)

model_info = mlflow.search_registered_models(filter_string="name='mlflow_example_agent'")[-1]
model_name = model_info.name
print("Name:", model_name)
print("Description:", model_info.description)
print("Tags:", model_info.tags)
latest_version = model_manager.get_latest_version(model_info.name)
print("Latest Version:", latest_version.version)

Name: mlflow_example_agent
Description: # MODEL CHANGELOG
## **Version 1**
### Author: mmilesi
### Date: 07/12/2023 12:46:10 CET
### Description: 
PPO Agent in CartPole-v1 Environment

Tags: {}
Latest Version: 1


### Registering a New Model Version from Checkpoint

Suppose we train a new PPO agent in the CartPole-v1 environment and obtain better results than before. In that case, we can register a new version of the model. To do this, we show another way to register models: not directly after training, but from a checkpoint.

First, we run another experiment with different hyperparameters (this time without the `model_manager.disabled=False` override, since the model will be registered from the checkpoint).

In [None]:
import os

# To retrieve the configs, we can simulate the cli command
# `python sheeprl.py exp=ppo algo.total_steps=16384 checkpoint.every=16384 logger@metric.logger=mlflow exp_name=mlflow_example metric.logger.tracking_uri="http://localhost:5000"`
with hydra.initialize(version_base="1.3", config_path="../sheeprl/configs"):
    cfg_ = hydra.compose(
        config_name="config.yaml",
        overrides=[
            "exp=ppo",
            "algo.total_steps=16384",
            "checkpoint.every=16384",
            "logger@metric.logger=mlflow",
            "exp_name=mlflow_example",
            "metric.logger.tracking_uri=http://localhost:5000",
        ],
    )
    cfg = dotdict(OmegaConf.to_container(cfg_, resolve=True, throw_on_missing=True))
run_algorithm(cfg)

# When launched from the CLI, Hydra writes `.hydra/config.yaml` in the run folder by itself;
# here the config was composed in the notebook, so we save it there manually
hydra_dir = os.path.join("./logs/runs", cfg.root_dir, cfg.run_name, ".hydra")
os.makedirs(hydra_dir, exist_ok=True)
OmegaConf.save(cfg_, os.path.join(hydra_dir, "config.yaml"))
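The registration step needs the path of a `.ckpt` file produced by this run. Assuming the checkpoints are saved somewhere under the run folder `./logs/runs/<root_dir>/<run_name>/` (the exact subfolder layout is not shown here), a hypothetical helper such as `latest_checkpoint` could locate the newest one by modification time.

```python
from pathlib import Path
from typing import Optional


def latest_checkpoint(run_dir: str) -> Optional[Path]:
    """Hypothetical helper: newest *.ckpt file under run_dir (searched recursively), or None."""
    checkpoints = sorted(Path(run_dir).rglob("*.ckpt"), key=lambda p: p.stat().st_mtime)
    return checkpoints[-1] if checkpoints else None


# Possible usage after the training cell:
# ckpt = latest_checkpoint(f"./logs/runs/{cfg.root_dir}/{cfg.run_name}")
```

Sorting by modification time rather than by file name avoids depending on any particular checkpoint naming scheme.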

Now we can use the `sheeprl_model_manager.py` script to load a checkpoint and register the models it contains.
We need the ID of the last run, so that the model is associated with the correct run. We can get it from the MLflow UI (in the browser) or with the `mlflow.search_runs(experiment_ids=[exp.experiment_id])` instruction.
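By default, `mlflow.search_runs()` returns the runs sorted by start time, newest first, so the first row is the most recent run. If the runs come from another source or in a different order, the choice can be made explicit. In the sketch below, `latest_run_id` is a hypothetical helper that takes a list of run records (for example `runs.to_dict(orient="records")` in pandas) with `run_id` and `start_time` keys.

```python
from typing import Any, Dict, List


def latest_run_id(runs: List[Dict[str, Any]]) -> str:
    """Hypothetical helper: run_id of the record with the most recent start_time."""
    if not runs:
        raise ValueError("the experiment has no runs")
    # start_time values must be mutually comparable (e.g. all timezone-aware datetimes)
    return max(runs, key=lambda run: run["start_time"])["run_id"]
```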

In [5]:
from sheeprl.cli import registration

# To retrieve the configs, we can simulate the cli command
# `python sheeprl_model_manager.py checkpoint_path=</path/to/checkpoint.ckpt> \
# model_manager=ppo model_manager.models.agent.description='New PPO Agent version trained in CartPole-v1 environment' \
# run.id=<run_id>`
runs = mlflow.search_runs(experiment_ids=[exp.experiment_id])
run_id = runs["run_id"][0]
with hydra.initialize(version_base="1.3", config_path="../sheeprl/configs"):
    cfg = hydra.compose(
        config_name="model_manager_config.yaml",
        overrides=[
            # Substitute the checkpoint path with your /path/to/checkpoint.ckpt
            "checkpoint_path=./path/to/checkpoint.ckpt",
            "model_manager=ppo",
            "model_manager.models.agent.description='New PPO Agent version trained in CartPole-v1 environment'",
            f"run.id={run_id}",
        ],
    )
registration(cfg)

Registered model 'mlflow_example_agent' already exists. Creating a new version of this model...
2023/12/07 12:47:04 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: mlflow_example_agent, version 2


Registered model mlflow_example_agent with version 2


Created version '2' of model 'mlflow_example_agent'.


We can then retrieve the updated information of the registered model.

In [6]:
model_info = mlflow.search_registered_models(filter_string=f"name='{model_name}'")[-1]
print("Name:", model_info.name)
print("Description:", model_info.description)
print("Tags:", model_info.tags)
latest_version = model_manager.get_latest_version(model_info.name)
print("Latest Version:", latest_version.version)

Name: mlflow_example_agent
Description: # MODEL CHANGELOG
## **Version 1**
### Author: mmilesi
### Date: 07/12/2023 12:46:10 CET
### Description: 
PPO Agent in CartPole-v1 Environment
## **Version 2**
### Author: mmilesi
### Date: 07/12/2023 12:47:04 CET
### Description: 
New PPO Agent version trained in CartPole-v1 environment

Tags: {}
Latest Version: 2
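The description printed above is a Markdown changelog with one `## **Version N**` section per version. If you need it in structured form, a parser along the following lines could work. `parse_changelog` is hypothetical, and the format it expects is inferred from the printed output rather than from a documented specification, so it may break if the model manager changes its template.

```python
import re
from typing import Dict, List, Optional

VERSION_HEADER = re.compile(r"^## \*\*Version (\d+)\*\*$")


def parse_changelog(description: str) -> List[Dict[str, str]]:
    """Hypothetical parser for the changelog format printed above (inferred, not documented)."""
    entries: List[Dict[str, str]] = []
    current: Optional[Dict[str, str]] = None
    for raw in description.splitlines():
        line = raw.rstrip()
        match = VERSION_HEADER.match(line)
        if match:
            current = {"version": match.group(1), "author": "", "date": "", "description": ""}
            entries.append(current)
        elif line.startswith("## "):
            current = None  # another section, e.g. "## **Transition:**" or "## **Deletion:**"
        elif current is None:
            continue  # title line or lines belonging to a non-version section
        elif line.startswith("### Author:"):
            current["author"] = line.split(":", 1)[1].strip()
        elif line.startswith("### Date:"):
            current["date"] = line.split(":", 1)[1].strip()
        elif not line.startswith("### Description:"):
            # free-text description lines follow the "### Description:" header
            current["description"] = "\n".join(filter(None, [current["description"], line]))
    return entries


# In the notebook, one would pass `model_info.description` instead of this copy of the output
changelog = (
    "# MODEL CHANGELOG\n"
    "## **Version 1**\n"
    "### Author: mmilesi\n"
    "### Date: 07/12/2023 12:46:10 CET\n"
    "### Description: \n"
    "PPO Agent in CartPole-v1 Environment\n"
    "## **Version 2**\n"
    "### Author: mmilesi\n"
    "### Date: 07/12/2023 12:47:04 CET\n"
    "### Description: \n"
    "New PPO Agent version trained in CartPole-v1 environment\n"
)
for entry in parse_changelog(changelog):
    print(entry["version"], entry["date"], entry["description"])
```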


### Staging the Model
After registering the model, we can transition it to a new stage, for example to `"staging"`, with the following command.

In [8]:
model_manager.transition_model(
    model_name="mlflow_example_agent",
    version=latest_version.version,
    stage="staging",
    description="Staging Model for demo",
)

Transitioning model mlflow_example_agent version 2 from None to staging


<ModelVersion: aliases=[], creation_timestamp=1701949624027, current_stage='Staging', description=('# MODEL CHANGELOG\n'
 '## **Version 2**\n'
 '### Author: mmilesi\n'
 '### Date: 07/12/2023 12:47:04 CET\n'
 '### Description: \n'
 'New PPO Agent version trained in CartPole-v1 environment\n'), last_updated_timestamp=1701949660778, name='mlflow_example_agent', run_id='eefbe09e8815463eaa83c6542cbc36c7', run_link='', source='mlflow-artifacts:/242317125620601262/eefbe09e8815463eaa83c6542cbc36c7/artifacts/agent', status='READY', status_message='', tags={}, user_id='', version='2'>
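The output shows that the requested `"staging"` stage was stored as `current_stage='Staging'`. The MLflow model registry uses four fixed stages (`None`, `Staging`, `Production`, `Archived`) and matches stage names case-insensitively. The hypothetical `normalize_stage` helper below sketches the same validation, which can be useful to catch typos before calling `transition_model`.

```python
MLFLOW_STAGES = ("None", "Staging", "Production", "Archived")


def normalize_stage(stage: str) -> str:
    """Hypothetical helper: canonical MLflow spelling of a case-insensitive stage name."""
    canonical = {name.lower(): name for name in MLFLOW_STAGES}
    key = stage.strip().lower()
    if key not in canonical:
        raise ValueError(f"unknown stage {stage!r}, expected one of {MLFLOW_STAGES}")
    return canonical[key]


print(normalize_stage("staging"))  # Staging
```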

### Downloading the Model
You can download the registered models and load them with the `torch.load()` function.

In [None]:
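import os

import torch

download_path = "./models/ppo-agent-cartpole"
model_manager.download_model(model_name, latest_version.version, download_path)
# The agent is stored in `<download_path>/agent/data/model.pth` as a pickled module (not just weights),
# so `weights_only=False` is needed on recent PyTorch releases, where it defaults to True
agent = torch.load(os.path.join(download_path, "agent", "data", "model.pth"), weights_only=False)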
agent
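The downloaded files follow the layout used in the cell above, `<download_path>/agent/data/model.pth`. A hypothetical helper that builds that path and fails with a clear message when the download is missing could look as follows; the `agent` subfolder and the `data/model.pth` file name are taken from the cell above, and other models may use a different layout.

```python
from pathlib import Path


def find_model_file(download_path: str, model_path: str = "agent") -> Path:
    """Hypothetical helper: <download_path>/<model_path>/data/model.pth, checked for existence."""
    candidate = Path(download_path) / model_path / "data" / "model.pth"
    if not candidate.is_file():
        raise FileNotFoundError(f"{candidate} not found: was the model downloaded to {download_path!r}?")
    return candidate
```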

### Register Best Models
Another possibility is to register the best models of a specific experiment. Suppose we want to register the best model among the two runs of the `mlflow_example` experiment we launched before: we only have to call the `model_manager.register_best_models()` function, specifying the `experiment_name`, the `models_info` (a Python dictionary containing the name, the path, the description, and the tags of the models we want to register) and, optionally, the `metric` used to rank the runs (it is not passed in the cell below, so its default value is used).

> [!NOTE]
>
> If your experiment contains different agents, and each agent has different model paths, you have to specify in `models_info` all the models you want to register (i.e., the union of the models of all the agents). The MLflow model manager automatically selects the correct models for each agent.
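For intuition on what "best" means here, the sketch below ranks run records by a metric column and returns the best one. It is not necessarily how `register_best_models()` is implemented: `best_run` is a hypothetical helper, the default metric `Test/cumulative_reward` is borrowed from the `metrics.Test/cumulative_reward` column of the runs table shown earlier, and runs without a valid value for the metric are skipped.

```python
import math
from typing import Any, Dict, List, Optional


def best_run(
    runs: List[Dict[str, Any]], metric: str = "Test/cumulative_reward", maximize: bool = True
) -> Optional[Dict[str, Any]]:
    """Hypothetical helper: run record with the best value in the `metrics.<metric>` column."""
    column = f"metrics.{metric}"
    # Skip runs that never logged the metric (missing key, None, or NaN as in a pandas table)
    scored = [
        run
        for run in runs
        if isinstance(run.get(column), (int, float)) and not math.isnan(run[column])
    ]
    if not scored:
        return None
    choose = max if maximize else min
    return choose(scored, key=lambda run: run[column])
```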

In [10]:
models_info = {
    "agent": {
        "name": "ppo_agent_cartpole_best_reward",
        "path": "agent",
        "tags": {},
        "description": "The best PPO Agent in CartPole environment.",
    }
}
model_manager.register_best_models("mlflow_example", models_info)

Successfully registered model 'ppo_agent_cartpole_best_reward'.
2023/12/07 12:47:55 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: ppo_agent_cartpole_best_reward, version 1


Registered model ppo_agent_cartpole_best_reward with version 1


Created version '1' of model 'ppo_agent_cartpole_best_reward'.


{'agent': <ModelVersion: aliases=[], creation_timestamp=1701949675859, current_stage='None', description='', last_updated_timestamp=1701949675859, name='ppo_agent_cartpole_best_reward', run_id='eefbe09e8815463eaa83c6542cbc36c7', run_link='', source='mlflow-artifacts:/242317125620601262/eefbe09e8815463eaa83c6542cbc36c7/artifacts/agent', status='READY', status_message='', tags={}, user_id='', version='1'>}

### Delete Model
Finally, you can delete the registered model versions you no longer need. Below, we delete the version that precedes the latest one.
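When several versions accumulate, a simple retention policy can decide which ones to remove. In the sketch below, `versions_to_delete` is a hypothetical helper that keeps the newest `keep_latest` versions; version numbers are converted to integers because MLflow reports them as strings (e.g. `version='2'` in the outputs above). The version strings could come, for instance, from `MlflowClient().search_model_versions(f"name='{model_name}'")`.

```python
from typing import Iterable, List


def versions_to_delete(versions: Iterable[str], keep_latest: int = 1) -> List[int]:
    """Hypothetical helper: version numbers to delete so that only the newest `keep_latest` remain."""
    if keep_latest < 0:
        raise ValueError("keep_latest must be non-negative")
    ordered = sorted(int(v) for v in versions)  # numeric order, so "10" comes after "9"
    return ordered[: max(len(ordered) - keep_latest, 0)]


print(versions_to_delete(["1", "2"]))  # [1]
```

Each returned version could then be passed to `model_manager.delete_model(model_name, version, description)`, with the same arguments used in the cell below.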

In [12]:
model_manager.delete_model(
    model_name, int(latest_version.version) - 1, f"Delete model version {int(latest_version.version)-1}"
)
mlflow.search_registered_models(filter_string="name='mlflow_example_agent'")[-1]

Model named mlflow_example_agent with version 1 does not exist


<RegisteredModel: aliases={}, creation_timestamp=1701949570369, description=('# MODEL CHANGELOG\n'
 '## **Version 1**\n'
 '### Author: mmilesi\n'
 '### Date: 07/12/2023 12:46:10 CET\n'
 '### Description: \n'
 'PPO Agent in CartPole-v1 Environment\n'
 '## **Version 2**\n'
 '### Author: mmilesi\n'
 '### Date: 07/12/2023 12:47:04 CET\n'
 '### Description: \n'
 'New PPO Agent version trained in CartPole-v1 environment\n'
 '## **Transition:**\n'
 '### Version 2 from None to Staging\n'
 '### Author: mmilesi\n'
 '### Date: 07/12/2023 12:47:40 CET\n'
 '### Description: \n'
 'Staging Model for demo\n'
 '## **Deletion:**\n'
 '### Version 1 from stage: None\n'
 '### Author: mmilesi\n'
 '### Date: 07/12/2023 12:48:36 CET\n'
 '### Description: \n'
 'Delete model version 1\n'), last_updated_timestamp=1701949716092, latest_versions=[<ModelVersion: aliases=[], creation_timestamp=1701949624027, current_stage='Staging', description=('# MODEL CHANGELOG\n'
 '## **Version 2**\n'
 '### Author: mmilesi\n'
 '