
Commit

Automatic model checkpointing for pytorch-lightning training (mlflow#10935)

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: mlflow-automation <mlflow-automation@users.noreply.github.com>
Signed-off-by: Ubuntu <weichen.xu@ip-10-110-25-111.us-west-2.compute.internal>
Co-authored-by: mlflow-automation <mlflow-automation@users.noreply.github.com>
Co-authored-by: Ubuntu <weichen.xu@ip-10-110-25-111.us-west-2.compute.internal>
3 people authored and sateeshmannar committed Feb 20, 2024
1 parent e954278 commit babc881
Showing 6 changed files with 632 additions and 2 deletions.
4 changes: 4 additions & 0 deletions docs/source/conf.py
@@ -380,6 +380,10 @@
("py:class", "keras.src.callbacks.callback.Callback"),
("py:class", "keras.callbacks.Callback"),
("py:class", "keras.src.callbacks.Callback"),
("py:class", "pytorch_lightning.callbacks.callback.Callback"),
("py:class", "pytorch_lightning.trainer.trainer.Trainer"),
("py:class", "pytorch_lightning.core.module.LightningModule"),
("py:class", "pytorch_lightning.core.LightningModule"),
]


2 changes: 1 addition & 1 deletion mlflow/pyfunc/__init__.py
@@ -2171,7 +2171,7 @@ def predict(self, context, model_input: List[str], params=None) -> List[str]:
with mlflow.start_run():
model_info = mlflow.pyfunc.log_model(artifact_path="model", python_model=MyModel()) # pylint: disable=line-too-long
model_info = mlflow.pyfunc.log_model(artifact_path="model", python_model=MyModel()) # noqa # pylint: disable=line-too-long
loaded_model = mlflow.pyfunc.load_model(model_uri=model_info.model_uri)
128 changes: 128 additions & 0 deletions mlflow/pytorch/__init__.py
@@ -24,6 +24,7 @@

import mlflow
from mlflow import pyfunc
from mlflow.client import MlflowClient
from mlflow.environment_variables import MLFLOW_DEFAULT_PREDICTION_DEVICE
from mlflow.exceptions import MlflowException
from mlflow.ml_package_versions import _ML_PACKAGE_VERSIONS
@@ -902,6 +903,12 @@ def autolog(
silent=False,
registered_model_name=None,
extra_tags=None,
checkpoint=True,
checkpoint_monitor="val_loss",
checkpoint_mode="min",
checkpoint_save_best_only=True,
checkpoint_save_weights_only=False,
checkpoint_save_freq="epoch",
): # pylint: disable=unused-argument
"""
Enables (or disables) and configures autologging from `PyTorch Lightning
@@ -956,6 +963,25 @@
new model version of the registered model with this name. The registered model is
created if it does not already exist.
extra_tags: A dictionary of extra tags to set on each managed run created by autologging.
checkpoint: Enable automatic model checkpointing. This feature only supports
pytorch-lightning >= 1.6.0.
checkpoint_monitor: In automatic model checkpointing, the metric name to monitor if
you set `checkpoint_save_best_only` to True.
checkpoint_save_best_only: If True, automatic model checkpointing only saves when
the model is considered the "best" model according to the quantity
monitored, and the previous checkpoint model is overwritten.
checkpoint_mode: one of {"min", "max"}. In automatic model checkpointing,
if `checkpoint_save_best_only` is True, the decision to overwrite the current save file
is made based on either the maximization or the minimization of the monitored quantity.
checkpoint_save_weights_only: In automatic model checkpointing, if True, then
only the model's weights will be saved. Otherwise, the optimizer states,
lr-scheduler states, etc. are added to the checkpoint too.
checkpoint_save_freq: `"epoch"` or integer. When using `"epoch"`, the callback
saves the model after each epoch. When using an integer, the callback
saves the model at the end of this many batches. Note that if the saving isn't
aligned to epochs, the monitored metric may potentially be less reliable (it
could reflect as little as 1 batch, since the metrics get reset
every epoch). Defaults to `"epoch"`.
.. code-block:: python
:test:
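
The example body under the ".. code-block::" directive above is collapsed in this diff view. Purely as an illustrative sketch (not the original example), the new checkpoint options could be enabled like this; `MyLightningModuleNet` and `train_loader` are placeholder names:

import mlflow
import pytorch_lightning as pl

# Enable autologging with automatic checkpointing. The values shown are the
# documented defaults; `checkpoint=True` requires pytorch-lightning >= 1.6.0.
mlflow.pytorch.autolog(
    checkpoint=True,
    checkpoint_monitor="val_loss",       # metric consulted when saving only the best model
    checkpoint_mode="min",               # a lower monitored value counts as better
    checkpoint_save_best_only=True,      # overwrite the previous best checkpoint
    checkpoint_save_weights_only=False,  # also keep optimizer / lr-scheduler state
    checkpoint_save_freq="epoch",        # or an integer number of batches
)

model = MyLightningModuleNet()  # placeholder LightningModule
with mlflow.start_run():
    pl.Trainer(max_epochs=5).fit(model, train_loader)  # placeholder DataLoader
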
@@ -1101,3 +1127,105 @@ def print_auto_logged_info(r):
autolog.__doc__ = autolog.__doc__.replace("MIN_REQ_VERSION", str(MIN_REQ_VERSION)).replace(
"MAX_REQ_VERSION", str(MAX_REQ_VERSION)
)


def load_checkpoint(model_class, run_id=None, epoch=None, global_step=None):
"""
If you enable "checkpoint" in autologging, during pytorch-lightning model
training execution, checkpointed models are logged as MLflow artifacts.
Using this API, you can load the checkpointed model.
If you want to load the latest checkpoint, set both `epoch` and `global_step` to None.
If "checkpoint_save_freq" is set to "epoch" in autologging,
you can set `epoch` param to the epoch of the checkpoint to load specific epoch checkpoint.
If "checkpoint_save_freq" is set to an integer in autologging,
you can set `global_step` param to the global step of the checkpoint to load specific
global step checkpoint.
`epoch` param and `global_step` can't be set together.
Args:
model_class: The class of the training model; it should inherit from
'pytorch_lightning.LightningModule'.
run_id: The id of the run that the model is logged to. If not provided,
the current active run is used.
epoch: The epoch of the checkpoint to be loaded, if you set
"checkpoint_save_freq" to "epoch".
global_step: The global step of the checkpoint to be loaded, if
you set "checkpoint_save_freq" to an integer.
Returns:
The instance of a pytorch-lightning model restored from the specified checkpoint.
.. code-block:: python
:caption: Example
import mlflow.pytorch
from pytorch_lightning import Trainer
mlflow.pytorch.autolog(checkpoint=True)
model = MyLightningModuleNet()  # A custom PyTorch Lightning model
trainer = Trainer()
with mlflow.start_run() as run:
trainer.fit(model)
run_id = run.info.run_id
# load latest checkpoint model
latest_checkpoint_model = mlflow.pytorch.load_checkpoint(MyLightningModuleNet, run_id)
# load the checkpoint model logged in the second epoch
checkpoint_model = mlflow.pytorch.load_checkpoint(MyLightningModuleNet, run_id, epoch=2)
"""
from mlflow.utils.mlflow_tags import LATEST_CHECKPOINT_ARTIFACT_TAG_KEY

client = MlflowClient()

if run_id is None:
run = mlflow.active_run()
if run is None:
raise MlflowException(
"There is no active run, please provide the 'run_id' for " "'load_checkpoint' call."
)
run_id = run.info.run_id
else:
run = client.get_run(run_id)

latest_checkpoint_artifact_path = run.data.tags.get(LATEST_CHECKPOINT_ARTIFACT_TAG_KEY)
if latest_checkpoint_artifact_path is None:
raise MlflowException("There is no logged checkpoint artifact in the current run.")

checkpoint_filename = os.path.basename(latest_checkpoint_artifact_path)

if epoch is not None and global_step is not None:
raise MlflowException(
"Only one of 'epoch' and 'global_step' can be set for 'load_checkpoint'."
)
elif global_step is not None:
checkpoint_artifact_path = f"checkpoints/global_step_{global_step}/{checkpoint_filename}"
elif epoch is not None:
checkpoint_artifact_path = f"checkpoints/epoch_{epoch}/{checkpoint_filename}"
else:
checkpoint_artifact_path = latest_checkpoint_artifact_path

downloaded_checkpoint_filepath = client.download_artifacts(run_id, checkpoint_artifact_path)
return model_class.load_from_checkpoint(downloaded_checkpoint_filepath)


__all__ = [
"autolog",
"load_model",
"save_model",
"log_model",
"get_default_pip_requirements",
"get_default_conda_env",
"load_checkpoint",
]

try:
from mlflow.pytorch._lightning_autolog import MlflowModelCheckpointCallback # noqa: F401

__all__.append("MLflowModelCheckpointCallback")
except ImportError:
# Swallow exception if pytorch-lightning is not installed.
pass
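
As a complement to the `load_checkpoint` docstring above, a hedged sketch of retrieving a checkpoint logged with an integer `checkpoint_save_freq`; `MyLightningModuleNet`, `train_loader`, and the step value are placeholders:

import mlflow
import pytorch_lightning as pl

mlflow.pytorch.autolog(checkpoint=True, checkpoint_save_freq=100)  # checkpoint every 100 batches

with mlflow.start_run() as run:
    pl.Trainer(max_epochs=1).fit(MyLightningModuleNet(), train_loader)

# Load the checkpoint logged at a specific global step ...
step_model = mlflow.pytorch.load_checkpoint(
    MyLightningModuleNet, run_id=run.info.run_id, global_step=100
)
# ... or simply load the latest checkpoint for the run.
latest_model = mlflow.pytorch.load_checkpoint(MyLightningModuleNet, run_id=run.info.run_id)
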
