77 changes: 0 additions & 77 deletions docs/source-fabric/advanced/model_parallel/fsdp.rst
@@ -413,83 +413,6 @@ Advanced performance optimizations
If you’ve reached a good understanding of how the different FSDP settings impact the memory usage and speed of your model, here are a few more to squeeze out the last bit of performance.
These settings really depend on the specific use cases, so you will have to turn them on and off to see the impact on your model.

Overlap backward and optimizer’s step
=====================================

Fabric provides a context manager that allows you to overlap the backward pass and the optimizer step, which saves significant memory and also speeds up iteration time.
By overlapping the two, we eliminate the need to store all gradients at once in memory.
Instead, the optimizer step updates are applied directly during backward as gradients become available, and the memory for gradients is immediately freed up.

Here is the recipe:

.. code-block:: python

    # 1. Import the context manager
    from lightning.fabric.strategies.fsdp import fsdp_overlap_step_with_backward

    # 2. Create one optimizer instance per parameter
    optimizers = [torch.optim.Adam([p], ...) for p in model.parameters()]
    model, *optimizers = fabric.setup(model, *optimizers)

    ...

    for i in range(max_iters):
        loss = ...

        # 3. Instead of calling `optimizer.step()`, call `fabric.backward(loss)`
        # within the context manager
        with fsdp_overlap_step_with_backward(optimizers, model):
            fabric.backward(loss)

        # optimizer.step()


.. collapse:: Full example

    .. code-block:: python

        import torch
        import torch.nn as nn
        import torch.nn.functional as F

        import lightning as L
        from lightning.fabric.strategies.fsdp import FSDPStrategy, fsdp_overlap_step_with_backward
        from lightning.pytorch.demos import Transformer, WikiText2

        policy = {nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}
        strategy = FSDPStrategy(auto_wrap_policy=policy)
        fabric = L.Fabric(accelerator="cuda", devices=2, strategy=strategy)
        fabric.launch()

        fabric.seed_everything(42)

        with fabric.rank_zero_first():
            dataset = WikiText2()

        # 1B parameters
        model = Transformer(vocab_size=dataset.vocab_size, nlayers=32, nhid=4096, ninp=1024, nhead=64)
        optimizers = [torch.optim.Adam([p], lr=0.1) for p in model.parameters()]

        model, *optimizers = fabric.setup(model, *optimizers)

        for i in range(10):
            input, target = fabric.to_device(dataset[i])
            output = model(input.unsqueeze(0), target.unsqueeze(0))
            loss = F.nll_loss(output, target.view(-1))

            with fsdp_overlap_step_with_backward(optimizers, model):
                fabric.backward(loss)
                # no `optimizer.step()` here!

            fabric.print(loss.item())

        fabric.print(torch.cuda.memory_summary())

|

`Read the detailed blog post here <https://lightning.ai/pages/community/tutorial/faster-pytorch-training-by-reducing-peak-memory/>`_.
Note that this feature cannot work with gradient accumulation!


Disable foreach in the optimizer
================================
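
A minimal, hypothetical sketch of what this setting looks like, assuming a standard ``torch.optim`` optimizer (the model and hyperparameters below are placeholders, not from this diff):

.. code-block:: python

    import torch

    model = torch.nn.Linear(8, 8)

    # `foreach=False` opts out of the multi-tensor ("foreach") update kernels and
    # falls back to the per-parameter loop, trading a bit of speed for lower peak memory.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, foreach=False)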
4 changes: 2 additions & 2 deletions docs/source-fabric/guide/checkpoint.rst
@@ -171,8 +171,8 @@ Here's an example of using a filter when saving a checkpoint:

state = {"model": model, "optimizer": optimizer, "foo": 123}

- # save only the model weights
- filter = {"model": lambda k, v: "weight"}
+ # save only the weights that match a pattern
+ filter = {"model": lambda k, v: "weight" in k}
fabric.save("path/to/checkpoint.ckpt", state, filter=filter)
# This will save {"model": {"layer.weight": ...}, "optimizer": ..., "foo": 123}
# note that the optimizer params corresponding to the excluded model params are not filtered
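
# A hypothetical, minimal illustration (not part of the changed file) of how a per-key
# filter like the one above selects entries: each callable receives (key, value) and a
# truthy return value keeps that entry.
example_state_dict = {"layer.weight": [1.0, 2.0], "layer.bias": [0.0]}
keep = lambda k, v: "weight" in k
filtered = {k: v for k, v in example_state_dict.items() if keep(k, v)}
assert filtered == {"layer.weight": [1.0, 2.0]}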
126 changes: 1 addition & 125 deletions src/lightning/fabric/strategies/fsdp.py
@@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
- import threading
- from contextlib import ExitStack, contextmanager
+ from contextlib import ExitStack
from datetime import timedelta
from functools import partial
from pathlib import Path
@@ -24,7 +23,6 @@
ContextManager,
Dict,
Generator,
- Iterable,
List,
Literal,
Optional,
@@ -76,8 +74,6 @@
if TYPE_CHECKING:
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision, ShardingStrategy

- from lightning.fabric.wrappers import _FabricModule

if _TORCH_GREATER_EQUAL_2_0:
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

@@ -892,123 +888,3 @@ def _has_meta_device_parameters(obj: Union[Module, Optimizer]) -> bool:
    if isinstance(obj, Module):
        return any(t.is_meta for t in obj.parameters())
    raise TypeError(f"Expected `torch.nn.Module` or `torch.optim.Optimizer`, got: {type(obj).__name__}")


def _no_op() -> None:
    pass


@contextmanager
def _apply_optimizers_during_fsdp_backward(
    optimizers: Union[Optimizer, Iterable[Optimizer]],
    module: Module,
) -> Generator[None, None, None]:
    """Call `Optimizer.step` as gradients become available.

    NOTE: This is an EXPERIMENTAL utility and exploits behavior which is not
    part of the FSDP public API. Use at your own risk.

    By moving optimizer step invocation into the backward call we can free
    gradients earlier and reduce peak memory.

    """
    from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
    from torch.distributed.fsdp._traversal_utils import _get_fsdp_handles
    from torch.distributed.fsdp.flat_param import FlatParameter, FlatParamHandle

    apply_lock = threading.Lock()

    param_handles = _get_fsdp_handles(module)
    assert param_handles, f"Module {module} does not appear to contain any FSDP modules."
    fsdp_state = _get_module_fsdp_state(module)
    assert fsdp_state is not None
    fsdp_stream = fsdp_state._post_backward_stream if _TORCH_GREATER_EQUAL_2_1 else fsdp_state._streams["post_backward"]

    if isinstance(optimizers, Optimizer):
        optimizers = [optimizers]

    # We cannot trigger the optimizer step until all parameters are ready.
    remaining = {}
    for optimizer in optimizers:
        unfinished: Dict[torch.nn.Parameter, None] = {}  # Use Dict as an ordered set.
        for group in optimizer.param_groups:
            for p in group["params"]:
                if p not in unfinished:
                    assert p not in remaining, f"{p=} is shared between two optimizers."
                    unfinished[p] = None
                    remaining[p] = (optimizer, unfinished)

    def maybe_step(parameters: Iterable[torch.nn.Parameter], post_step: Callable[[], None] = _no_op) -> None:
        for p in tuple(parameters):
            optimizer, unfinished = remaining.pop(p)
            unfinished.pop(p)
            if not unfinished:
                optimizer.step()
                optimizer.zero_grad()

                # Used to call `_reset_flat_param_grad_info_if_needed`. Otherwise FSDP might hold on to the memory.
                post_step()

    try:
        hook_handles = []
        for h in param_handles:
            assert isinstance(h, FlatParamHandle)
            flat_param = h.flat_param
            fsdp_acc_grad, _ = flat_param._post_backward_hook_state  # type: ignore

            # We must take `h` and `flat_param` as arguments because Python
            # late binds closures.
            def _opt_hook(h: FlatParamHandle, flat_param: FlatParameter, *_unused: Any) -> None:
                assert flat_param._post_backward_called
                assert h.flat_param is flat_param
                with apply_lock, torch.cuda.stream(fsdp_stream):
                    # We invoke `prepare_gradient_for_optim` earlier than usual.
                    # We also need to prevent the later "normal" invocation,
                    # otherwise the double call will trigger FSDP asserts.
                    prepare_gradient = h.prepare_gradient_for_optim
                    assert hasattr(prepare_gradient, "__func__"), prepare_gradient
                    assert prepare_gradient.__func__ is FlatParamHandle.prepare_gradient_for_optim
                    prepare_gradient()
                    h.prepare_gradient_for_optim = _no_op  # type: ignore[method-assign]
                    post_step = (
                        h._reset_flat_param_grad_info_if_needed
                        if _TORCH_GREATER_EQUAL_2_1
                        else h._clear_grads_if_needed
                    )
                    maybe_step(flat_param._params or (), post_step=post_step)

            hook = partial(_opt_hook, h, flat_param)
            hook_handles.append(fsdp_acc_grad.register_hook(hook))

        yield

    finally:
        # Non-FSDP parameters won't have a grad hook, so handle them here.
        with apply_lock:
            maybe_step(remaining)

        # Unregister the grad hooks.
        for hook_handle in hook_handles:
            hook_handle.remove()

        # And lastly back out the handle monkey patches.
        for h in param_handles:
            if h.prepare_gradient_for_optim is _no_op:
                del h.prepare_gradient_for_optim


def fsdp_overlap_step_with_backward(
    optimizers: Union[Optimizer, Iterable[Optimizer]],
    fabric_module: "_FabricModule",
) -> Generator[None, None, None]:
    if not _TORCH_GREATER_EQUAL_2_0:
        raise NotImplementedError(
            "`fsdp_overlap_step_with_backward` requires torch >= 2.0.0. HINT: `pip install -U torch`"
        )

    from lightning.fabric.wrappers import _FabricModule

    assert isinstance(fabric_module, _FabricModule)
    return _apply_optimizers_during_fsdp_backward(  # type: ignore[return-value]
        optimizers, fabric_module._forward_module
    )