Support consolidating sharded checkpoints with the fabric CLI (#19560)
awaelchli committed Mar 4, 2024
1 parent d9113b6 commit 13f15b3
Showing 4 changed files with 70 additions and 17 deletions.
@@ -187,7 +187,7 @@ It is possible to convert a distributed checkpoint to a regular, single-file checkpoint

.. code-block:: bash
python -m lightning.fabric.utilities.consolidate_checkpoint path/to/my/checkpoint
fabric consolidate path/to/my/checkpoint
You will need to do this for example if you want to load the checkpoint into a script that doesn't use FSDP, or need to export the checkpoint to a different format for deployment, evaluation, etc.

@@ -202,7 +202,7 @@ You will need to do this for example if you want to load the checkpoint into a script

.. code-block:: bash
python -m lightning.fabric.utilities.consolidate_checkpoint my-checkpoint.ckpt
fabric consolidate my-checkpoint.ckpt
This saves a new file ``my-checkpoint.ckpt.consolidated`` next to the sharded checkpoint which you can load normally in PyTorch:
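The snippet that follows here in the docs is collapsed in this view. As a minimal sketch of that load (the top-level keys depend on the state that was saved through Fabric):

.. code-block:: python

    import torch

    # Load the consolidated single-file checkpoint produced by ``fabric consolidate``.
    checkpoint = torch.load("my-checkpoint.ckpt.consolidated")

    # Inspect what was saved; the keys mirror the state that was passed to Fabric's save().
    print(checkpoint.keys())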

2 changes: 1 addition & 1 deletion src/lightning/fabric/CHANGELOG.md
@@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

-
- Enabled consolidating distributed checkpoints through `fabric consolidate` in the new CLI ([#19560](https://github.com/Lightning-AI/pytorch-lightning/pull/19560))

-

34 changes: 34 additions & 0 deletions src/lightning/fabric/cli.py
@@ -19,14 +19,17 @@
from argparse import Namespace
from typing import Any, List, Optional

import torch
from lightning_utilities.core.imports import RequirementCache
from typing_extensions import get_args

from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator
from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS
from lightning.fabric.strategies import STRATEGY_REGISTRY
from lightning.fabric.utilities.consolidate_checkpoint import _process_cli_args
from lightning.fabric.utilities.device_parser import _parse_gpu_ids
from lightning.fabric.utilities.distributed import _suggested_max_num_threads
from lightning.fabric.utilities.load import _load_distributed_checkpoint

_log = logging.getLogger(__name__)

@@ -154,6 +157,37 @@ def _run(**kwargs: Any) -> None:
script_args = list(kwargs.pop("script_args", []))
main(args=Namespace(**kwargs), script_args=script_args)

@_main.command(
"consolidate",
context_settings={
"ignore_unknown_options": True,
},
)
@click.argument(
"checkpoint_folder",
type=click.Path(exists=True),
)
@click.option(
"--output_file",
type=click.Path(exists=False),
default=None,
help=(
"Path to the file where the converted checkpoint should be saved. The file should not already exist."
" If no path is provided, the file will be saved next to the input checkpoint folder with the same name"
" and a '.consolidated' suffix."
),
)
def _consolidate(checkpoint_folder: str, output_file: Optional[str]) -> None:
"""Convert a distributed/sharded checkpoint into a single file that can be loaded with `torch.load()`.
Only supports FSDP sharded checkpoints at the moment.
"""
args = Namespace(checkpoint_folder=checkpoint_folder, output_file=output_file)
config = _process_cli_args(args)
checkpoint = _load_distributed_checkpoint(config.checkpoint_folder)
torch.save(checkpoint, config.output_file)


def _set_env_variables(args: Namespace) -> None:
"""Set the environment variables for the new processes.
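For reference, the body of the new ``consolidate`` command above is thin enough to sketch programmatically. The helpers involved are private (underscore-prefixed), so treat this as an illustration of what ``fabric consolidate`` does rather than a supported API; the checkpoint path is a placeholder:

.. code-block:: python

    from argparse import Namespace

    import torch

    from lightning.fabric.utilities.consolidate_checkpoint import _process_cli_args
    from lightning.fabric.utilities.load import _load_distributed_checkpoint

    # Validate the arguments and resolve the output path (defaults to a ".consolidated" suffix).
    config = _process_cli_args(Namespace(checkpoint_folder="path/to/checkpoint", output_file=None))

    # Assemble the sharded checkpoint into a single object and save it as one file.
    checkpoint = _load_distributed_checkpoint(config.checkpoint_folder)
    torch.save(checkpoint, config.output_file)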
47 changes: 33 additions & 14 deletions tests/tests_fabric/test_cli.py
@@ -20,7 +20,7 @@
from unittest.mock import Mock

import pytest
from lightning.fabric.cli import _get_supported_strategies, _run
from lightning.fabric.cli import _consolidate, _get_supported_strategies, _run

from tests_fabric.helpers.runif import RunIf

@@ -33,7 +33,7 @@ def fake_script(tmp_path):


@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_cli_env_vars_defaults(monkeypatch, fake_script):
def test_run_env_vars_defaults(monkeypatch, fake_script):
monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock())
with pytest.raises(SystemExit) as e:
_run.main([fake_script])
@@ -49,7 +49,7 @@ def test_cli_env_vars_defaults(monkeypatch, fake_script):
@pytest.mark.parametrize("accelerator", ["cpu", "gpu", "cuda", pytest.param("mps", marks=RunIf(mps=True))])
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2)
def test_cli_env_vars_accelerator(_, accelerator, monkeypatch, fake_script):
def test_run_env_vars_accelerator(_, accelerator, monkeypatch, fake_script):
monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock())
with pytest.raises(SystemExit) as e:
_run.main([fake_script, "--accelerator", accelerator])
@@ -60,23 +60,23 @@ def test_cli_env_vars_accelerator(_, accelerator, monkeypatch, fake_script):
@pytest.mark.parametrize("strategy", _get_supported_strategies())
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2)
def test_cli_env_vars_strategy(_, strategy, monkeypatch, fake_script):
def test_run_env_vars_strategy(_, strategy, monkeypatch, fake_script):
monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock())
with pytest.raises(SystemExit) as e:
_run.main([fake_script, "--strategy", strategy])
assert e.value.code == 0
assert os.environ["LT_STRATEGY"] == strategy


def test_cli_get_supported_strategies():
def test_run_get_supported_strategies():
"""Test to ensure that when new strategies get added, we must consider updating the list of supported ones in the
CLI."""
assert len(_get_supported_strategies()) == 7
assert "fsdp" in _get_supported_strategies()


@pytest.mark.parametrize("strategy", ["ddp_spawn", "ddp_fork", "ddp_notebook", "deepspeed_stage_3_offload"])
def test_cli_env_vars_unsupported_strategy(strategy, fake_script):
def test_run_env_vars_unsupported_strategy(strategy, fake_script):
ioerr = StringIO()
with pytest.raises(SystemExit) as e, contextlib.redirect_stderr(ioerr):
_run.main([fake_script, "--strategy", strategy])
@@ -87,7 +87,7 @@ def test_cli_env_vars_unsupported_strategy(strategy, fake_script):
@pytest.mark.parametrize("devices", ["1", "2", "0,", "1,0", "-1"])
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2)
def test_cli_env_vars_devices_cuda(_, devices, monkeypatch, fake_script):
def test_run_env_vars_devices_cuda(_, devices, monkeypatch, fake_script):
monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock())
with pytest.raises(SystemExit) as e:
_run.main([fake_script, "--accelerator", "cuda", "--devices", devices])
@@ -98,7 +98,7 @@ def test_cli_env_vars_devices_cuda(_, devices, monkeypatch, fake_script):
@RunIf(mps=True)
@pytest.mark.parametrize("accelerator", ["mps", "gpu"])
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_cli_env_vars_devices_mps(accelerator, monkeypatch, fake_script):
def test_run_env_vars_devices_mps(accelerator, monkeypatch, fake_script):
monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock())
with pytest.raises(SystemExit) as e:
_run.main([fake_script, "--accelerator", accelerator])
@@ -108,7 +108,7 @@ def test_cli_env_vars_devices_mps(accelerator, monkeypatch, fake_script):

@pytest.mark.parametrize("num_nodes", ["1", "2", "3"])
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_cli_env_vars_num_nodes(num_nodes, monkeypatch, fake_script):
def test_run_env_vars_num_nodes(num_nodes, monkeypatch, fake_script):
monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock())
with pytest.raises(SystemExit) as e:
_run.main([fake_script, "--num-nodes", num_nodes])
@@ -118,7 +118,7 @@ def test_cli_env_vars_num_nodes(num_nodes, monkeypatch, fake_script):

@pytest.mark.parametrize("precision", ["64-true", "64", "32-true", "32", "16-mixed", "bf16-mixed"])
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_cli_env_vars_precision(precision, monkeypatch, fake_script):
def test_run_env_vars_precision(precision, monkeypatch, fake_script):
monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock())
with pytest.raises(SystemExit) as e:
_run.main([fake_script, "--precision", precision])
@@ -127,7 +127,7 @@ def test_cli_env_vars_precision(precision, monkeypatch, fake_script):


@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_cli_torchrun_defaults(monkeypatch, fake_script):
def test_run_torchrun_defaults(monkeypatch, fake_script):
torchrun_mock = Mock()
monkeypatch.setitem(sys.modules, "torch.distributed.run", torchrun_mock)
with pytest.raises(SystemExit) as e:
@@ -155,7 +155,7 @@ def test_cli_torchrun_defaults(monkeypatch, fake_script):
)
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=5)
def test_cli_torchrun_num_processes_launched(_, devices, expected, monkeypatch, fake_script):
def test_run_torchrun_num_processes_launched(_, devices, expected, monkeypatch, fake_script):
torchrun_mock = Mock()
monkeypatch.setitem(sys.modules, "torch.distributed.run", torchrun_mock)
with pytest.raises(SystemExit) as e:
@@ -171,15 +171,15 @@ def test_cli_torchrun_num_processes_launched(_, devices, expected, monkeypatch, fake_script):
])


def test_cli_through_fabric_entry_point():
def test_run_through_fabric_entry_point():
result = subprocess.run("fabric run --help", capture_output=True, text=True, shell=True)

message = "Usage: fabric run [OPTIONS] SCRIPT [SCRIPT_ARGS]"
assert message in result.stdout or message in result.stderr


@pytest.mark.skipif("lightning.fabric" == "lightning_fabric", reason="standalone package")
def test_cli_through_lightning_entry_point():
def test_run_through_lightning_entry_point():
result = subprocess.run("lightning run model --help", capture_output=True, text=True, shell=True)

deprecation_message = (
@@ -189,3 +189,22 @@ def test_cli_through_lightning_entry_point():
message = "Usage: lightning run [OPTIONS] SCRIPT [SCRIPT_ARGS]"
assert deprecation_message in result.stdout
assert message in result.stdout or message in result.stderr


@mock.patch("lightning.fabric.cli._process_cli_args")
@mock.patch("lightning.fabric.cli._load_distributed_checkpoint")
@mock.patch("lightning.fabric.cli.torch.save")
def test_consolidate(save_mock, _, __, tmp_path):
ioerr = StringIO()
with pytest.raises(SystemExit) as e, contextlib.redirect_stderr(ioerr):
_consolidate.main(["not exist"])
assert e.value.code == 2
assert "Path 'not exist' does not exist" in ioerr.getvalue()

checkpoint_folder = tmp_path / "checkpoint"
checkpoint_folder.mkdir()
ioerr = StringIO()
with pytest.raises(SystemExit) as e, contextlib.redirect_stderr(ioerr):
_consolidate.main([str(checkpoint_folder)])
assert e.value.code == 0
save_mock.assert_called_once()
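As a side note, the same first check could be written with click's ``CliRunner``, which catches the ``SystemExit`` and captures output itself. A hypothetical variant, not part of this commit:

.. code-block:: python

    from click.testing import CliRunner

    from lightning.fabric.cli import _consolidate


    def test_consolidate_rejects_missing_folder():
        runner = CliRunner()
        # click validates the CHECKPOINT_FOLDER argument (exists=True) and exits with code 2.
        result = runner.invoke(_consolidate, ["not exist"])
        assert result.exit_code == 2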
