From b38250c9b14b140bebf034fba15d20f577584d27 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 5 Oct 2023 17:07:33 +0200 Subject: [PATCH 1/3] mini doc update --- docs/source-fabric/guide/checkpoint.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source-fabric/guide/checkpoint.rst b/docs/source-fabric/guide/checkpoint.rst index 35a8c0b85b033..116f68fc4b68c 100644 --- a/docs/source-fabric/guide/checkpoint.rst +++ b/docs/source-fabric/guide/checkpoint.rst @@ -171,8 +171,8 @@ Here's an example of using a filter when saving a checkpoint: state = {"model": model, "optimizer": optimizer, "foo": 123} - # save only the model weights - filter = {"model": lambda k, v: "weight"} + # save only the weights that match a pattern + filter = {"model": lambda k, v: "weight" in k} fabric.save("path/to/checkpoint.ckpt", state, filter=filter) # This will save {"model": {"layer.weight": ...}, "optimizer": ..., "foo": 123} # note that the optimizer params corresponding to the excluded model params are not filtered From 68721595becf1212d5ca76fd32caf91e1ba102cb Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 5 Oct 2023 17:13:44 +0200 Subject: [PATCH 2/3] remove --- .../advanced/model_parallel/fsdp.rst | 77 ------- src/lightning/fabric/strategies/fsdp.py | 125 +---------- tests/tests_fabric/strategies/test_fsdp.py | 200 ------------------ 3 files changed, 1 insertion(+), 401 deletions(-) diff --git a/docs/source-fabric/advanced/model_parallel/fsdp.rst b/docs/source-fabric/advanced/model_parallel/fsdp.rst index 8f55b282df8fa..6707f9834d698 100644 --- a/docs/source-fabric/advanced/model_parallel/fsdp.rst +++ b/docs/source-fabric/advanced/model_parallel/fsdp.rst @@ -413,83 +413,6 @@ Advanced performance optimizations If you’ve reached a good understanding of how the different FSDP settings impact the memory usage and speed of your model, here are a few more to squeeze out the last bit of performance. These settings really depend on the specific use cases, so you will have to turn them on and off to see the impact on your model. -Overlap backward and optimizer’s step -===================================== - -Fabric provides a context manager that allows you to overlap the backward and optimizer step to save significant memory and speed up the iteration time too. -By overlapping the two, we eliminate the need to store all gradients at once in memory. -Instead, the optimizer step updates are applied directly during backward as gradients become available, and the memory for gradients is immediately freed up. - -Here is the recipe: - -.. code-block:: python - - # 1. Import the context manager - from lightning.fabric.strategies.fsdp import fsdp_overlap_step_with_backward - - # 2. Create one optimizer instance per parameter - optimizers = [torch.optim.Adam([p], ...) for p in model.parameters()] - model, *optimizers = fabric.setup(model, *optimizers) - - ... - - for i in range(max_iters): - loss = ... - - # 3. Instead of calling `optimizer.step()`, call `fabric.backward(loss)` - # within the context manager - with fsdp_overlap_step_with_backward(optimizers, model): - fabric.backward(loss) - - # optimizer.step() - - -.. collapse:: Full example - - .. code-block:: python - - import torch - import torch.nn as nn - import torch.nn.functional as F - - import lightning as L - from lightning.fabric.strategies.fsdp import FSDPStrategy, fsdp_overlap_step_with_backward - from lightning.pytorch.demos import Transformer, WikiText2 - - policy = {nn.TransformerEncoderLayer, nn.TransformerDecoderLayer} - strategy = FSDPStrategy(auto_wrap_policy=policy) - fabric = L.Fabric(accelerator="cuda", devices=2, strategy=strategy) - fabric.launch() - - fabric.seed_everything(42) - - with fabric.rank_zero_first(): - dataset = WikiText2() - - # 1B parameters - model = Transformer(vocab_size=dataset.vocab_size, nlayers=32, nhid=4096, ninp=1024, nhead=64) - optimizers = [torch.optim.Adam([p], lr=0.1) for p in model.parameters()] - - model, *optimizers = fabric.setup(model, *optimizers) - - for i in range(10): - input, target = fabric.to_device(dataset[i]) - output = model(input.unsqueeze(0), target.unsqueeze(0)) - loss = F.nll_loss(output, target.view(-1)) - - with fsdp_overlap_step_with_backward(optimizers, model): - fabric.backward(loss) - # no `optimizer.step()` here! - - fabric.print(loss.item()) - - fabric.print(torch.cuda.memory_summary()) - -| - -`Read the detailed blog post here `_. -Note that this feature cannot work with gradient accumulation! - Disable foreach in the optimizer ================================ diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index bf80b62b248fa..86e9f68b05794 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ b/src/lightning/fabric/strategies/fsdp.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -import threading -from contextlib import ExitStack, contextmanager +from contextlib import ExitStack from datetime import timedelta from functools import partial from pathlib import Path @@ -76,8 +75,6 @@ if TYPE_CHECKING: from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision, ShardingStrategy - from lightning.fabric.wrappers import _FabricModule - if _TORCH_GREATER_EQUAL_2_0: from torch.distributed.fsdp.wrap import ModuleWrapPolicy @@ -892,123 +889,3 @@ def _has_meta_device_parameters(obj: Union[Module, Optimizer]) -> bool: if isinstance(obj, Module): return any(t.is_meta for t in obj.parameters()) raise TypeError(f"Expected `torch.nn.Module` or `torch.optim.Optimizer`, got: {type(obj).__name__}") - - -def _no_op() -> None: - pass - - -@contextmanager -def _apply_optimizers_during_fsdp_backward( - optimizers: Union[Optimizer, Iterable[Optimizer]], - module: Module, -) -> Generator[None, None, None]: - """Call `Optimizer.step` as gradients become available. - - NOTE: This is an EXPERIMENTAL utility and exploits behavior which is not - part of the FSDP public API. Use at your own risk. - - By moving optimizer step invocation into the backward call we can free - gradients earlier and reduce peak memory. - - """ - from torch.distributed.fsdp._common_utils import _get_module_fsdp_state - from torch.distributed.fsdp._traversal_utils import _get_fsdp_handles - from torch.distributed.fsdp.flat_param import FlatParameter, FlatParamHandle - - apply_lock = threading.Lock() - - param_handles = _get_fsdp_handles(module) - assert param_handles, f"Module {module} does not appear to contain any FSDP modules." - fsdp_state = _get_module_fsdp_state(module) - assert fsdp_state is not None - fsdp_stream = fsdp_state._post_backward_stream if _TORCH_GREATER_EQUAL_2_1 else fsdp_state._streams["post_backward"] - - if isinstance(optimizers, Optimizer): - optimizers = [optimizers] - - # We cannot trigger the optimizer step until all parameters are ready. - remaining = {} - for optimizer in optimizers: - unfinished: Dict[torch.nn.Parameter, None] = {} # Use Dict as an ordered set. - for group in optimizer.param_groups: - for p in group["params"]: - if p not in unfinished: - assert p not in remaining, f"{p=} is shared between two optimizers." - unfinished[p] = None - remaining[p] = (optimizer, unfinished) - - def maybe_step(parameters: Iterable[torch.nn.Parameter], post_step: Callable[[], None] = _no_op) -> None: - for p in tuple(parameters): - optimizer, unfinished = remaining.pop(p) - unfinished.pop(p) - if not unfinished: - optimizer.step() - optimizer.zero_grad() - - # Used to call `_reset_flat_param_grad_info_if_needed`. Otherwise FSDP might hold on to the memory. - post_step() - - try: - hook_handles = [] - for h in param_handles: - assert isinstance(h, FlatParamHandle) - flat_param = h.flat_param - fsdp_acc_grad, _ = flat_param._post_backward_hook_state # type: ignore - - # We must take `h` and `flat_param` as arguments because Python - # late binds closures. - def _opt_hook(h: FlatParamHandle, flat_param: FlatParameter, *_unused: Any) -> None: - assert flat_param._post_backward_called - assert h.flat_param is flat_param - with apply_lock, torch.cuda.stream(fsdp_stream): - # We invoke `prepare_gradient_for_optim` earlier than usual. - # We also need to prevent the later "normal" invocation, - # otherwise the double call will trigger FSDP asserts. - prepare_gradient = h.prepare_gradient_for_optim - assert hasattr(prepare_gradient, "__func__"), prepare_gradient - assert prepare_gradient.__func__ is FlatParamHandle.prepare_gradient_for_optim - prepare_gradient() - h.prepare_gradient_for_optim = _no_op # type: ignore[method-assign] - post_step = ( - h._reset_flat_param_grad_info_if_needed - if _TORCH_GREATER_EQUAL_2_1 - else h._clear_grads_if_needed - ) - maybe_step(flat_param._params or (), post_step=post_step) - - hook = partial(_opt_hook, h, flat_param) - hook_handles.append(fsdp_acc_grad.register_hook(hook)) - - yield - - finally: - # Non-FSDP parameters won't have a grad hook, so handle them here. - with apply_lock: - maybe_step(remaining) - - # Unregister the grad hooks. - for hook_handle in hook_handles: - hook_handle.remove() - - # And lastly back out the handle monkey patches. - for h in param_handles: - if h.prepare_gradient_for_optim is _no_op: - del h.prepare_gradient_for_optim - - -def fsdp_overlap_step_with_backward( - optimizers: Union[Optimizer, Iterable[Optimizer]], - fabric_module: "_FabricModule", -) -> Generator[None, None, None]: - if not _TORCH_GREATER_EQUAL_2_0: - raise NotImplementedError( - "`fsdp_overlap_step_with_backward` requires torch >= 2.0.0. HINT: `pip install -U torch`" - ) - - from lightning.fabric.wrappers import _FabricModule - - assert isinstance(fabric_module, _FabricModule) - return _apply_optimizers_during_fsdp_backward( # type: ignore[return-value] - optimizers, fabric_module._forward_module - ) diff --git a/tests/tests_fabric/strategies/test_fsdp.py b/tests/tests_fabric/strategies/test_fsdp.py index a398297156a51..f6f5abacc2c99 100644 --- a/tests/tests_fabric/strategies/test_fsdp.py +++ b/tests/tests_fabric/strategies/test_fsdp.py @@ -11,9 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import contextlib -import datetime -import os from datetime import timedelta from re import escape from unittest import mock @@ -23,7 +20,6 @@ import pytest import torch import torch.nn as nn -from lightning.fabric import Fabric from lightning.fabric.plugins import HalfPrecision from lightning.fabric.plugins.environments import LightningEnvironment from lightning.fabric.strategies import FSDPStrategy @@ -31,10 +27,8 @@ _FSDPBackwardSyncControl, _get_full_state_dict_context, _has_meta_device_parameters, - fsdp_overlap_step_with_backward, ) from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1 -from lightning_utilities.core.imports import RequirementCache from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision from torch.optim import Adam @@ -450,197 +444,3 @@ def test_get_full_state_dict_context_offload(set_type_mock, monkeypatch, torch_g with _get_full_state_dict_context(module=Mock(spec=FullyShardedDataParallel), world_size=4): assert set_type_mock.call_args_list[0][0][2].offload_to_cpu # model config assert set_type_mock.call_args_list[0][0][3].offload_to_cpu # optim config - - -class SubBlock(nn.Sequential): - def __init__(self, feature_dim: int) -> None: - super().__init__( - nn.Linear(feature_dim, feature_dim, bias=False), - torch.nn.LayerNorm([feature_dim]), - nn.ReLU(), - ) - - -class Block(nn.Module): - def __init__(self, feature_dim: int) -> None: - super().__init__() - self.left = SubBlock(feature_dim) - self.right = SubBlock(feature_dim) - - def forward(self, x): - return self.left(x) + self.right(x) - - -class StatusChecker: - def __init__(self, fabric: Fabric) -> None: - self._fabric = fabric - self.is_rank_zero = fabric.is_global_zero - self.pids = tuple(int(pid) for pid in fabric.all_gather(os.getpid()).cpu().numpy()) - - @contextlib.contextmanager - def guard_region(self, name: str): - """Handle errors and graceful shutdown. - - `pytest` interprets SystemExit as a faiure, so it will interpret shutdown of non-zero ranks as a test failure. - This is confusing (since it logs "FAILED"), but more importantly the orphan rank will continue trying to execute - the rest of the test suite. So instead we add calls to `os._exit` which actually forces the process to shut - down. - - """ - success = False - try: - yield - success = True - - except BaseException: - if self.is_rank_zero: - raise - - finally: - # All reduce will wait for all workers to enter. This means that if a - # worker dies the status check will deadlock. - import psutil - - worker_status = tuple(psutil.Process(pid).status() for pid in self.pids) - if any( - status in (psutil.STATUS_DEAD, psutil.STATUS_STOPPED, psutil.STATUS_ZOMBIE) for status in worker_status - ): - if self.is_rank_zero: - raise RuntimeError(f"({name}) Dead workers: [{', '.join(worker_status)}]") - else: - os._exit(1) - - rank_success = self._fabric.all_gather(success).cpu() - if not rank_success.all(): - if self.is_rank_zero > 0: - os._exit(1) - elif success: - raise RuntimeError(f"({name}) Failure on different rank: {rank_success}") - - def finalize(self) -> None: - if not self.is_rank_zero: - os._exit(0) - - def __del__(self) -> None: - self.finalize() - - -@pytest.mark.xfail(strict=False, reason="Flaky test") # See also: https://github.com/Lightning-AI/lightning/pull/17774 -@RunIf(min_torch="2.0.0", min_cuda_gpus=2, skip_windows=True, standalone=True) -@pytest.mark.skipif(not RequirementCache("psutil"), reason="psutil is needed to help prevent deadlocks.") -@pytest.mark.parametrize( - "checkpoint", - [(Block,), (SubBlock,), (Block, SubBlock, nn.Linear), None], -) -def test_apply_optimizer_in_backward(checkpoint): - from torch.distributed.fsdp._traversal_utils import _get_fsdp_handles - - num_gpus = 2 - num_blocks = 8 - feature_dim = 256 - - # This bound is dependent on the topology of the model. The grads for each - # Block are two `feature_dim ** 2` Tensors (`left` and `right` Linear layers) - # times four. (FP32 = 4 bytes / element) - # - # In the baseline case grads for all Blocks are materialized at once, whereas - # in the test case only one Block should have grads in memory which adds a - # `(num_blocks - 1)` factor. - # - # However, there is one final correction to be made. In the baseline case peak - # memory occurs at the end of the backward pass; at that time activations will - # have been freed and will offset the memory relative to the base case. (Which - # reaches peak memory after the first block when most activations are still - # in memory.) It's difficult to estimate the exact correction factor - # (particularly since it varies with activation checkpointing strategy), but - # three is close enough for our purposes. - upper_savings_bound = 4 * feature_dim**2 * 2 * (num_blocks - 1) - lower_savings_bound = upper_savings_bound / 3 - - strategy = FSDPStrategy( - auto_wrap_policy={Block}, - activation_checkpointing=checkpoint, - timeout=datetime.timedelta(seconds=10), - ) - fabric = Fabric(accelerator="cuda", devices=num_gpus, strategy=strategy) - fabric.launch() - status_checker = StatusChecker(fabric) - - def make_model_and_optimizers(): - torch.manual_seed(0) - - with fabric.init_module(): - backbone = [Block(feature_dim) for _ in range(num_blocks)] - model = nn.Sequential(*backbone, nn.Linear(feature_dim, 1, bias=False)) - optimizers = [torch.optim.SGD(layer.parameters(), lr=0.1, momentum=0.9) for layer in model] - - return fabric.setup_module(model), fabric.setup_optimizers(*optimizers) - - with status_checker.guard_region("Instantiate model."): - baseline_model, baseline_optimizers = make_model_and_optimizers() - test_model, test_optimizers = make_model_and_optimizers() - fabric.seed_everything(1337 + fabric.global_rank) - - # Check that initialization is identical. - for p0, p1 in zip(baseline_model.parameters(), test_model.parameters()): - assert (p0 == p1).all() - - num_steps = 50 - for step in range(num_steps): - # Normal pattern: `.backward()` followed by `.step()` - with status_checker.guard_region(f"({step + 1} / {num_steps}) Baseline"): - x = torch.randn((4, feature_dim), device=fabric.device) - y_baseline = baseline_model(x) - - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats(fabric.device) - baseline_start_memory = torch.cuda.max_memory_allocated(fabric.device) - fabric.backward(y_baseline.mean().abs()) - del y_baseline - for optimizer in baseline_optimizers: - optimizer.step() - optimizer.zero_grad(set_to_none=True) - - # FSDP sometimes holds onto grad memory until the next forward - # pass. In order to provide a fair comparison (and thus an - # accurate check that moving the step call into backward actually - # delivers the expected memory savings) we need to "help" the - # baseline case a bit. - param_handles = _get_fsdp_handles(baseline_model._forward_module) - for h in param_handles: - if _TORCH_GREATER_EQUAL_2_1: - h._reset_flat_param_grad_info_if_needed() - else: - h._clear_grads_if_needed() - - baseline_peak_memory = torch.cuda.max_memory_allocated(fabric.device) - - # `.step()` interleaved with `.backward()` - with status_checker.guard_region(f"({step + 1} / {num_steps}) Optimizer in backward"): - y_test = test_model(x) - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats(fabric.device) - test_start_memory = torch.cuda.memory_allocated(fabric.device) - with fsdp_overlap_step_with_backward(test_optimizers, test_model): - fabric.backward(y_test.mean().abs()) - del y_test - - test_peak_memory = torch.cuda.max_memory_allocated(fabric.device) - - # Make sure the parameter updates match. - with status_checker.guard_region(f"({step + 1} / {num_steps}) Check equality"): - for idx, (p0, p1) in enumerate(zip(baseline_model.parameters(), test_model.parameters())): - assert (p0 == p1).all(), (step, idx, p0, p1) - - # The first step is going to be odd due to lazy initialization of optimizer state. - if not step: - continue - - with status_checker.guard_region(f"({step + 1} / {num_steps}) Confirm memory reduction"): - baseline_delta = baseline_peak_memory - baseline_start_memory - test_delta = test_peak_memory - test_start_memory - assert (baseline_delta - test_delta) >= lower_savings_bound, (baseline_delta, test_delta) - assert (baseline_delta - test_delta) <= upper_savings_bound, (baseline_delta, test_delta) - - status_checker.finalize() - assert (pid := os.getpid()) == status_checker.pids[0], f"Orphan worker: {pid}" From 0f7bef93f4e37c8917a1f5c120f768e913710af9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Oct 2023 15:17:19 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/fabric/strategies/fsdp.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index 86e9f68b05794..51226250c2ba1 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ b/src/lightning/fabric/strategies/fsdp.py @@ -23,7 +23,6 @@ ContextManager, Dict, Generator, - Iterable, List, Literal, Optional,