Merged
16 changes: 8 additions & 8 deletions docs/source/checkpoints/dist_ckpt.rst
@@ -124,7 +124,7 @@ Basic Sharding
--------------

The main way to define the relationship of a plain, local PyTorch tensor to tensors on other ranks is by wrapping it in a ``ShardedTensor`` class.
This allows to express the fact that a given local tensor is part of a larger *grid* of tensors of a given shape at a given offset.
This expresses that a given local tensor is part of a larger *grid* of tensors of a given shape at a given offset.
Instead of saving a simple state dict with ``torch.Tensor``, we save a *sharded* state dict with ``dist_checkpointing.ShardedTensor``.

Example: assume we have a tensor (composed of 128 elements) divided equally across the whole workload, which we want to save and load with a different number of ranks.
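The grid/offset relationship can be sketched with plain arithmetic (a toy illustration of the bookkeeping a ``ShardedTensor`` records, not the actual ``dist_checkpointing`` API):

```python
def shard_bounds(global_numel: int, rank: int, world_size: int):
    """Return (offset, length) of the contiguous shard owned by ``rank``.

    Toy sketch of the grid/offset bookkeeping that ``ShardedTensor``
    carries; assumes ``global_numel`` divides evenly across ranks.
    """
    assert global_numel % world_size == 0
    length = global_numel // world_size
    return rank * length, length

# The same 128-element tensor can be saved on 4 ranks and later
# loaded on 8: offsets are global, so the mapping stays well defined.
save_layout = [shard_bounds(128, r, 4) for r in range(4)]
load_layout = [shard_bounds(128, r, 8) for r in range(8)]
```

Because each shard is described by its global offset rather than by the rank that produced it, a checkpoint written with one layout can be read back with the other.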
@@ -170,7 +170,7 @@ Example: assume we have a tensor (composed of 128 elements) divided equally acro
dist_checkpointing.save(sharded_state_dict, dist_ckpt_root)

During load, the distributed checkpoint can be easily read even if the job size changes (contrary to native checkpoints that require the same number of ranks).
The main difference with wrt. ``torch.load`` is that the user has to provide the definition of the sharded state dict that needs to be loaded.
The main difference with respect to ``torch.load`` is that the user has to provide the definition of the sharded state dict that needs to be loaded.

.. code-block:: python

@@ -296,21 +296,21 @@ There are several useful user entry points for checkpoint saving and loading.
dist_checkpointing.save
^^^^^^^^^^^^^^^^^^^^^^^
The ``dist_checkpointing.save`` function is the only entry point for checkpoint saving.
It requires providing a sharded state dict to save and saving strategies for handling different entities (see `Save and load strategies`_ for detailed explanation).
It requires a sharded state dict to save and saving strategies for handling different entities (see `Save and load strategies`_ for detailed explanation).
The sharded state dict is processed in the following way (see also ``save`` function `documentation <https://docs.nvidia.com/megatron-core/developer-guide/latest/api-guide/dist_checkpointing.html#module-core.dist_checkpointing.serialization>`_):

1. The ShardedTensorFactories are applied.
2. The LocalNonPersistentObjects are extracted from the sharded state dict and ignored.
3. The ShardedBase objects are extracted.
4. All other objects are treated as "common" and saved according to a sharded strategy (see `Save and load strategies`_).
5. All ShardedObjects are extracted from point (3) objects and saved with a common strategy (see `Save and load strategies`_).
4. All other objects are treated as "common" and saved according to a common strategy (see `Save and load strategies`_).
5. All ShardedObjects are extracted from point (3) objects and saved with a sharded strategy (see `Save and load strategies`_).
6. All ShardedTensors are saved.
7. The ``metadata.json`` file with backend and version metadata is saved to the checkpoint directory.
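The classification in steps 2-6 can be sketched as follows (the class names mirror ``dist_checkpointing`` types, but these stand-ins and the helper are illustrative assumptions, not the library's code):

```python
# Illustrative stand-ins for the dist_checkpointing type hierarchy.
class ShardedBase: ...
class ShardedTensor(ShardedBase): ...
class ShardedObject(ShardedBase): ...
class LocalNonPersistentObject: ...

def split_for_saving(sharded_state_dict):
    """Partition a flat state dict into the groups each strategy handles."""
    common, sharded_objects, sharded_tensors = {}, {}, {}
    for key, value in sharded_state_dict.items():
        if isinstance(value, LocalNonPersistentObject):
            continue                          # step 2: extracted and ignored
        elif isinstance(value, ShardedObject):
            sharded_objects[key] = value      # step 5: saved with the common strategy
        elif isinstance(value, ShardedTensor):
            sharded_tensors[key] = value      # step 6: saved with the sharded strategy
        else:
            common[key] = value               # step 4: "common" objects
    return common, sharded_objects, sharded_tensors
```

The real implementation works on nested dicts and applies factories first (step 1), but the routing decision per entry follows this shape.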

dist_checkpointing.load
^^^^^^^^^^^^^^^^^^^^^^^
The ``dist_checkpointing.load`` function is the main entry point for checkpoint loading.
It requires providing a sharded state dict (in order to implicitly define mappings between local tensors and checkpoint tensors) and loading strategies.
It requires a sharded state dict (in order to implicitly define mappings between local tensors and checkpoint tensors) and loading strategies.
In practice, the same sharded state dict can usually be used for both saving and loading (the sharded state dict for loading will just contain tensors with uninitialized data).
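A common pattern is therefore one function that builds the mapping for both directions (a toy dict-based sketch, not real ``ShardedTensor`` objects; the key and field names are made up for illustration):

```python
def build_sharded_state_dict(local_data):
    """Describe where the local shard lives in the global grid (toy sketch)."""
    return {'model.weight': {'data': local_data,
                             'global_shape': (128,),
                             'local_offset': (0,)}}

# Save wraps the trained values; load wraps an uninitialized buffer of
# the same local shape, which dist_checkpointing.load then fills in.
to_save = build_sharded_state_dict([0.5] * 32)
to_load = build_sharded_state_dict([0.0] * 32)
```

Only the ``data`` field differs between the two calls; the sharding metadata, which is what ``load`` actually needs, is identical.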

When the sharded state dict is provided as input, it is processed in the following way (see also ``load`` function `documentation <https://docs.nvidia.com/megatron-core/developer-guide/latest/api-guide/dist_checkpointing.html#module-core.dist_checkpointing.serialization>`_):
@@ -404,7 +404,7 @@ For "common" strategies, currently the only supported one is ``torch`` which sav
PyTorch Distributed
^^^^^^^^^^^^^^^^^^^
The PyTorch Distributed based checkpoint format uses the ``torch.distributed.checkpoint`` package in order to serialize the checkpoints to storage.
The Megatron Core sharded state dicts are translated into ``torch.distributed.SharedTensor`` and then ``torch.distributed.checkpoint`` primitives are used to serialize such state dicts.
The Megatron Core sharded state dicts are translated into ``torch.distributed.ShardedTensor`` and then ``torch.distributed.checkpoint`` primitives are used to serialize such state dicts.
Even though Megatron Core provides several saving optimizations, the underlying checkpoint can still be read with native `PyTorch loading methods <https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict_loader.load>`_.
Note that the checkpoint still follows the ``dist_checkpointing`` package format by providing additional ``common.pt`` and ``metadata.json`` files described above.
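Because the format keeps those extra files, a ``dist_checkpointing`` directory can be cheaply distinguished from other checkpoints. The filenames below come from this document; the helper itself is an assumed convenience, not part of the package:

```python
from pathlib import Path

def looks_like_dist_checkpoint(ckpt_dir) -> bool:
    """True if the directory carries the dist_checkpointing marker files."""
    root = Path(ckpt_dir)
    return (root / 'metadata.json').is_file() and (root / 'common.pt').is_file()
```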

@@ -425,7 +425,7 @@ This mapping can be used by the ``dist_checkpointing.optimizers.optim_state_to_s
This should support most optimizer cases, but some of them might require custom sharded state dict creation.
A good example is a Distributed Optimizer which flattens the parameters - see `Tensors transformations`_ section for more details.

Note: In order to reuse model SharderTensors to create optimizer ShardedTensors, the model **SharderTensors must wrap model parameters**, not just tensors
Note: In order to reuse model ShardedTensors to create optimizer ShardedTensors, the model **ShardedTensors must wrap model parameters**, not just tensors
(obtaining a state dict with model parameters can be achieved by passing ``keep_vars=True`` to the model ``state_dict`` function).
Otherwise the correspondence between model ShardedTensors and optimizer states is impossible to recreate.
This is the reason for introducing ShardedTensorFactories: we register the original model parameter as ``ShardedTensorFactories.data`` and apply any subsequent transformations as a factory function, so that the same transformation can be applied to the optimizer states.
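The identity requirement can be seen with plain objects (a toy stand-in for ``torch.nn.Parameter``, not Megatron Core code): optimizer state is matched to model tensors by parameter identity, so a detached copy produced by the default ``state_dict()`` call cannot be mapped back.

```python
class Param:
    """Stand-in for torch.nn.Parameter."""
    def __init__(self, value):
        self.value = value

p = Param(1.0)
optim_state = {id(p): {'exp_avg': 0.0}}      # optimizer keyed by the parameter

model_sd_keep_vars = {'w': p}                # keep_vars=True: the same object
model_sd_detached = {'w': Param(p.value)}    # detached copy: new identity

# Only the keep_vars form lets us recover the matching optimizer state:
assert id(model_sd_keep_vars['w']) in optim_state
assert id(model_sd_detached['w']) not in optim_state
```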