Fabric always wraps the root module with FSDP #18005

Closed
carmocca opened this issue Jul 6, 2023 · 1 comment · Fixed by #18054
Assignees: carmocca
Labels: bug (Something isn't working), fabric (lightning.fabric.Fabric), strategy: fsdp (Fully Sharded Data Parallel), ver: 2.1.x
Milestone: 2.1

Comments

@carmocca
Contributor

carmocca commented Jul 6, 2023

Bug description

When the user manually wraps specific sections of the module with FSDP, without using an auto-wrap policy, for example:

import torch
from torch.distributed.fsdp.wrap import wrap

l1 = torch.nn.Linear(1, 1)
l2 = torch.nn.Linear(1, 1)
model = torch.nn.Sequential(wrap(l1), torch.nn.ReLU(), wrap(l2))

Fabric still wraps the root module in FSDP.

This can defeat the memory savings the user expects. From the FSDP docs:

[when wrapping the root without a policy] FSDP will put the entire model in one FSDP unit, which will reduce computation efficiency and memory efficiency.
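
For comparison, here is a minimal sketch of the difference between wrapping only the root and using an auto-wrap policy that creates nested units. It assumes a process group is already initialized (e.g. via fabric.launch()) and PyTorch 2.0-style FSDP APIs; the layer sizes are illustrative only.

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

def make_model():
    return torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.ReLU(), torch.nn.Linear(1, 1))

# No policy: the whole model becomes a single flat FSDP unit, so all
# parameters are gathered at once during forward/backward.
root_only = FSDP(make_model())

# With ModuleWrapPolicy: each Linear gets its own nested FSDP unit, so
# parameters can be gathered and freed layer by layer.
per_layer = FSDP(make_model(), auto_wrap_policy=ModuleWrapPolicy({torch.nn.Linear}))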

This is not a problem with the Trainer because the user must do the wrapping inside the configure_sharded_model hook, and the current behaviour is that FSDPStrategy.setup_module is skipped if that hook is defined: https://github.com/Lightning-AI/lightning/blob/f4240ca42c75ad67b2655351b38830fa0ba82cba/src/lightning/pytorch/strategies/fsdp.py#L265-L271
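
One possible direction for Fabric (a sketch only; the helper below is illustrative and not necessarily what #18054 implements) would be to detect user-managed wrapping and skip wrapping the root in that case:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def _has_fsdp_submodules(module):
    # Illustrative helper, not part of Fabric's API: detect manual wrapping by the user.
    return any(isinstance(m, FSDP) for m in module.modules())

# Hypothetical guard inside FSDPStrategy.setup_module:
# if _has_fsdp_submodules(module) and auto_wrap_policy is None:
#     return module  # trust the user's manual wrapping, do not wrap the root
# return FSDP(module, ...)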

What version are you seeing the problem on?

master

How to reproduce the bug

import torch
import lightning as L
from torch.distributed.fsdp.wrap import wrap
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel

fabric = L.Fabric(devices=1, accelerator="cuda", strategy="fsdp")
fabric.launch()

with fabric.init_module():
    l1 = torch.nn.Linear(1, 1)
    l2 = torch.nn.Linear(1, 1)
    model = torch.nn.Sequential(wrap(l1), torch.nn.ReLU(), wrap(l2))

assert not isinstance(model, FullyShardedDataParallel)

"""
Sequential(
  (0): FullyShardedDataParallel(
    (_fsdp_wrapped_module): Linear(in_features=1, out_features=1, bias=True)
  )
  (1): ReLU()
  (2): FullyShardedDataParallel(
    (_fsdp_wrapped_module): Linear(in_features=1, out_features=1, bias=True)
  )
)
"""

# FAILS: assumes the root is FSDP wrapped
# fabric.save("/tmp/foo", {"model": model})

model = fabric.setup_module(model)

# AssertionError
assert not isinstance(model._forward_module, FullyShardedDataParallel)

"""
_FabricModule(
  (_forward_module): FullyShardedDataParallel(
    (_fsdp_wrapped_module): Sequential(
      (0): FullyShardedDataParallel(
        (_fsdp_wrapped_module): Linear(in_features=1, out_features=1, bias=True)
      )
      (1): ReLU()
      (2): FullyShardedDataParallel(
        (_fsdp_wrapped_module): Linear(in_features=1, out_features=1, bias=True)
      )
    )
  )
  (_original_module): Sequential(
    (0): FullyShardedDataParallel(
      (_fsdp_wrapped_module): Linear(in_features=1, out_features=1, bias=True)
    )
    (1): ReLU()
    (2): FullyShardedDataParallel(
      (_fsdp_wrapped_module): Linear(in_features=1, out_features=1, bias=True)
    )
  )
)
"""

Error messages and logs

No response

Environment

No response

More info

Discovered during #18004

The complex part of this bugfix is that the current FSDP checkpointing logic assumes that the root module is wrapped.
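
For reference, a sketch of collecting a full state dict when only submodules are wrapped, reusing the model from the repro above. This assumes FSDP.state_dict_type applies the setting to all FSDP instances found under the given module, even when the root itself is a plain nn.Module:

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullStateDictConfig,
    StateDictType,
)

# Gather the full (unsharded) state dict on rank 0, offloaded to CPU.
config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, config):
    state_dict = model.state_dict()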

I found PyTorch tests with non-root FSDP modules: https://github.com/pytorch/pytorch/blob/main/torch/testing/_internal/common_fsdp.py#L438-L446

cc @carmocca @justusschock @awaelchli

@carmocca carmocca added bug Something isn't working fabric lightning.fabric.Fabric strategy: fsdp Fully Sharded Data Parallel labels Jul 6, 2023
@carmocca carmocca added this to the 2.1 milestone Jul 6, 2023
@carmocca carmocca self-assigned this Jul 6, 2023
@carmocca carmocca changed the title from "Fabric FSDP always wraps the root module with FSDP" to "Fabric always wraps the root module with FSDP" Jul 6, 2023
@RuABraun

RuABraun commented Apr 4, 2024

So will Fabric correctly handle the unwrapped parts of the model, treating them as it would under DDP, even though FSDP is the selected strategy?
