Revise checkpoint consolidation with PyTorch 2.3 (#19561)

awaelchli committed Mar 4, 2024
1 parent 527d071 commit b3c869f

Showing 6 changed files with 23 additions and 91 deletions.
6 changes: 3 additions & 3 deletions src/lightning/fabric/utilities/consolidate_checkpoint.py
@@ -4,7 +4,7 @@
 
 import torch
 
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
 from lightning.fabric.utilities.load import _METADATA_FILENAME, _load_distributed_checkpoint
 
 _log = logging.getLogger(__name__)
@@ -38,8 +38,8 @@ def _parse_cli_args() -> Namespace:
 
 
 def _process_cli_args(args: Namespace) -> Namespace:
-    if not _TORCH_GREATER_EQUAL_2_1:
-        _log.error("Processing distributed checkpoints requires PyTorch >= 2.1.")
+    if not _TORCH_GREATER_EQUAL_2_3:
+        _log.error("Processing distributed checkpoints requires PyTorch >= 2.3.")
         exit(1)
 
     checkpoint_folder = Path(args.checkpoint_folder)
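
Note on usage: the script gated above boils down to loading the sharded checkpoint into a regular state dict and re-saving it as a single file. A minimal sketch, assuming PyTorch >= 2.3 and a sharded checkpoint at an illustrative path; only functions that appear in this commit are used, and the output path is a placeholder:

    from pathlib import Path

    import torch

    from lightning.fabric.utilities.load import _load_distributed_checkpoint

    # Illustrative path to a folder written by a distributed (FSDP/DCP) save.
    sharded_dir = Path("my_model/checkpoint_sharded")
    # Loads the entire checkpoint into CPU memory as a plain nested dict.
    full_state_dict = _load_distributed_checkpoint(sharded_dir)
    torch.save(full_state_dict, sharded_dir.with_suffix(".consolidated"))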
69 changes: 12 additions & 57 deletions src/lightning/fabric/utilities/load.py
@@ -16,7 +16,7 @@
 from functools import partial
 from io import BytesIO
 from pathlib import Path
-from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, OrderedDict, Sequence, Set, Tuple, Union
+from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, OrderedDict, Sequence, Set, Union
 
 import torch
 from lightning_utilities.core.apply_func import apply_to_collection
@@ -27,8 +27,7 @@
 
 from lightning.fabric.utilities.imports import (
     _TORCH_GREATER_EQUAL_2_0,
-    _TORCH_GREATER_EQUAL_2_1,
-    _TORCH_GREATER_EQUAL_2_2,
+    _TORCH_GREATER_EQUAL_2_3,
 )
 from lightning.fabric.utilities.types import _PATH, _Stateful
 
@@ -243,68 +242,24 @@ def _load_distributed_checkpoint(checkpoint_folder: Path) -> Dict[str, Any]:
     The current implementation assumes that the entire checkpoint fits in CPU memory.
     """
-    if not _TORCH_GREATER_EQUAL_2_1:
-        raise ImportError("Processing distributed checkpoints requires PyTorch >= 2.1.")
+    if not _TORCH_GREATER_EQUAL_2_3:
+        raise ImportError("Processing distributed checkpoints requires PyTorch >= 2.3.")
 
     from torch.distributed.checkpoint import FileSystemReader
-    from torch.distributed.checkpoint.metadata import BytesStorageMetadata, TensorStorageMetadata
+    from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
+    from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
 
-    if _TORCH_GREATER_EQUAL_2_2:
-        from torch.distributed.checkpoint import load
-    else:
-        from torch.distributed.checkpoint import load_state_dict as load  # deprecated
-
-    reader = FileSystemReader(checkpoint_folder)
-    metadata = reader.read_metadata()
-
-    # TODO: Add sequential save to avoid storing the entire checkpoint in memory
     checkpoint: Dict[str, Any] = {}
-    for tensor_name, sd_metadata in metadata.state_dict_metadata.items():
-        if isinstance(sd_metadata, BytesStorageMetadata):
-            checkpoint[tensor_name] = "<bytes_io>"
-        elif isinstance(sd_metadata, TensorStorageMetadata):
-            checkpoint[tensor_name] = torch.empty(
-                size=sd_metadata.size,
-                dtype=sd_metadata.properties.dtype,
-                device=torch.device("cpu"),
-                memory_format=sd_metadata.properties.memory_format,
-                layout=sd_metadata.properties.layout,
-                requires_grad=sd_metadata.properties.requires_grad,
-                pin_memory=sd_metadata.properties.pin_memory,
-            )
-
-    load(state_dict=checkpoint, storage_reader=reader, no_dist=True)
-    checkpoint = _unflatten_dict(checkpoint, key_map=metadata.planner_data)
+    _load_state_dict(
+        checkpoint,
+        storage_reader=FileSystemReader(checkpoint_folder),
+        planner=_EmptyStateDictLoadPlanner(),
+        no_dist=True,
+    )
 
     # This is the extra file saved by Fabric, with user data separate from weights and optimizer states
     extra_file = checkpoint_folder / _METADATA_FILENAME
     extra = torch.load(extra_file, map_location="cpu") if extra_file.is_file() else {}
     checkpoint.update(extra)
 
     return checkpoint
-
-
-def _unflatten_dict(checkpoint: Dict[str, Any], key_map: Dict[str, Tuple[str, ...]]) -> Dict[str, Any]:
-    """Converts the flat dictionary with keys 'x.y.z...' to a nested dictionary using the provided key map.
-
-    Args:
-        checkpoint: The flat checkpoint dictionary.
-        key_map: A dictionary that maps the keys in flattened format 'x.y.z...' to a tuple representing
-            the index path into the nested dictonary that this function should construct.
-    """
-    assert checkpoint.keys() == key_map.keys()
-    converted: Dict[str, Any] = {}
-    for flat_key in checkpoint:
-        key_path = key_map[flat_key]
-        _set_nested_dict_value(converted, key_path, checkpoint[flat_key])
-    return converted
-
-
-def _set_nested_dict_value(nested_dict: Dict[str, Any], key_path: Tuple[str, ...], value: Any) -> None:
-    result = nested_dict
-    for key in key_path[:-1]:
-        if key not in result:
-            result[key] = {}
-        result = result[key]
-    result[key_path[-1]] = value
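
For context on why the manual metadata walk and _unflatten_dict could be dropped: PyTorch 2.3 ships the same planner-based mechanism behind a public helper in torch.distributed.checkpoint.format_utils, which is presumably why 2.3 becomes the minimum version here. A hedged sketch, treating the exact helper name and signature as an assumption about the PyTorch 2.3 API; the paths are placeholders:

    # Assumes PyTorch >= 2.3 and an existing DCP checkpoint directory.
    # Unlike _load_distributed_checkpoint above, this does not merge in Fabric's extra
    # metadata file (_METADATA_FILENAME), which Fabric handles explicitly.
    from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

    dcp_to_torch_save("my_model/checkpoint_sharded", "my_model/checkpoint_consolidated.pt")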
6 changes: 3 additions & 3 deletions tests/tests_fabric/strategies/test_fsdp_integration.py
@@ -621,8 +621,7 @@ def test_clip_gradients(clip_type, precision):
     optimizer.zero_grad()
 
 
-# TODO: Support checkpoint consolidation with PyTorch >= 2.2
-@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.1.0", max_torch="2.2.0")
+@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.3.0")
 def test_save_sharded_and_consolidate_and_load(tmp_path):
     """Test the consolidation of a FSDP-sharded checkpoint into a single file."""
 
@@ -639,7 +638,8 @@ def test_save_sharded_and_consolidate_and_load(tmp_path):
     state = {"model": model, "optimizer": optimizer, "steps": 1}
 
     # run one iteration to init the state of the optimizer
-    model(torch.rand(1, 32, device=fabric.device)).sum().backward()
+    loss = model(torch.rand(1, 32, device=fabric.device)).sum()
+    fabric.backward(loss)
     optimizer.step()
 
     checkpoint_path_sharded = fabric.broadcast(str(tmp_path / "checkpoint_sharded"))
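
The functional change in this test is that the backward pass now goes through Fabric instead of calling .backward() on the tensor directly. A minimal sketch of the updated step, assuming model and optimizer were already set up through fabric.setup(...) as in the test:

    # fabric.backward() lets Fabric apply precision scaling and strategy-specific handling
    # (relevant for FSDP) rather than invoking Tensor.backward() directly.
    loss = model(torch.rand(1, 32, device=fabric.device)).sum()
    fabric.backward(loss)
    optimizer.step()
    optimizer.zero_grad()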
8 changes: 4 additions & 4 deletions tests/tests_fabric/utilities/test_consolidate_checkpoint.py
@@ -39,15 +39,15 @@ def test_parse_cli_args(args, expected):
 
 
 def test_process_cli_args(tmp_path, caplog, monkeypatch):
-    # PyTorch version < 2.1
-    monkeypatch.setattr(lightning.fabric.utilities.consolidate_checkpoint, "_TORCH_GREATER_EQUAL_2_1", False)
+    # PyTorch version < 2.3
+    monkeypatch.setattr(lightning.fabric.utilities.consolidate_checkpoint, "_TORCH_GREATER_EQUAL_2_3", False)
     with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises(
         SystemExit
     ):
         _process_cli_args(Namespace())
-    assert "requires PyTorch >= 2.1." in caplog.text
+    assert "requires PyTorch >= 2.3." in caplog.text
     caplog.clear()
-    monkeypatch.setattr(lightning.fabric.utilities.consolidate_checkpoint, "_TORCH_GREATER_EQUAL_2_1", True)
+    monkeypatch.setattr(lightning.fabric.utilities.consolidate_checkpoint, "_TORCH_GREATER_EQUAL_2_3", True)
 
     # Checkpoint does not exist
     checkpoint_folder = Path("does/not/exist")
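
For readers unfamiliar with the pattern in this test: the module-level version flag is monkeypatched so the error path can be exercised regardless of the installed PyTorch. A condensed sketch of just that gate check (the test name is illustrative; the calls and assertions mirror the diff above):

    import logging
    from argparse import Namespace

    import pytest

    import lightning.fabric.utilities.consolidate_checkpoint as consolidate
    from lightning.fabric.utilities.consolidate_checkpoint import _process_cli_args


    def test_gate_requires_torch_2_3(monkeypatch, caplog):
        # Force the "old PyTorch" branch without touching the real installation.
        monkeypatch.setattr(consolidate, "_TORCH_GREATER_EQUAL_2_3", False)
        with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises(
            SystemExit
        ):
            _process_cli_args(Namespace())
        assert "requires PyTorch >= 2.3." in caplog.text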
22 changes: 0 additions & 22 deletions tests/tests_fabric/utilities/test_load.py
@@ -19,7 +19,6 @@
     _materialize_tensors,
     _move_state_into,
     _NotYetLoadedTensor,
-    _unflatten_dict,
 )
 
 from tests_fabric.helpers.runif import RunIf
@@ -145,24 +144,3 @@ def load_state_dict(self, state_dict):
     assert source == {}
     assert destination["cocofruit"] == 2
     assert destination["banana"].count == 100
-
-
-def test_unflatten_dict():
-    assert _unflatten_dict({}, {}) == {}
-
-    tensor0 = torch.rand(2, 2)
-    tensor1 = torch.tensor(3.0)
-    data = {
-        "model.layer.weight": tensor0,
-        "optimizer.state.layer.weight.exp_avg": {"test": tensor1},
-        "optimizer.param_groups": "param_groups",
-    }
-    key_map = {
-        "model.layer.weight": ("model", "layer.weight"),
-        "optimizer.state.layer.weight.exp_avg": ("optimizer", "state", "layer.weight", "exp_avg"),
-        "optimizer.param_groups": ("optimizer", "param_groups"),
-    }
-    assert _unflatten_dict(data, key_map) == {
-        "model": {"layer.weight": tensor0},
-        "optimizer": {"state": {"layer.weight": {"exp_avg": {"test": tensor1}}}, "param_groups": "param_groups"},
-    }
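
The deleted test covered the manual unflattening step that load.py no longer performs: with _EmptyStateDictLoadPlanner, the loaded state dict is expected to come back already nested. A small sketch of the new code path in isolation, assuming PyTorch >= 2.3 and an existing sharded checkpoint folder at an illustrative path:

    from pathlib import Path
    from typing import Any, Dict

    from torch.distributed.checkpoint import FileSystemReader
    from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
    from torch.distributed.checkpoint.state_dict_loader import _load_state_dict

    state_dict: Dict[str, Any] = {}
    _load_state_dict(
        state_dict,
        storage_reader=FileSystemReader(Path("my_model/checkpoint_sharded")),
        planner=_EmptyStateDictLoadPlanner(),
        no_dist=True,
    )
    # Flat keys such as "model.layer.weight" should materialize as nested entries,
    # e.g. state_dict["model"]["layer.weight"], making a separate _unflatten_dict pass unnecessary.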
3 changes: 1 addition & 2 deletions tests/tests_pytorch/strategies/test_fsdp.py
@@ -1013,8 +1013,7 @@ def _run_setup_assertions(empty_init, expected_device):
     _run_setup_assertions(empty_init=True, expected_device=torch.device("cpu"))
 
 
-# TODO: Support checkpoint consolidation with PyTorch >= 2.2
-@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.1.0", max_torch="2.2.0")
+@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.3.0")
 def test_save_sharded_and_consolidate_and_load(tmp_path):
     """Test the consolidation of a FSDP-sharded checkpoint into a single file."""
 
