Cannot save and load checkpoints with FSDPStrategy #17702

Closed
nlpTRIZ opened this issue May 26, 2023 · 9 comments · Fixed by #17819
Labels
bug Something isn't working strategy: fsdp Fully Sharded Data Parallel ver: 2.1.x

Comments

@nlpTRIZ

nlpTRIZ commented May 26, 2023

Bug description

Hello,
I must be missing something, but when using FSDPStrategy with 2 GPUs and the code below I run into several problems:

  • only rank 0 saves the model
  • parameters are not mapped correctly in the state dict
  • ModelSummary gives info for rank 0 only (I guess this one is normal)

Is it possible to save a full state dict with FSDP (and to load the model afterwards on a different number of GPUs)?
Thank you for your help

What version are you seeing the problem on?

master

How to reproduce the bug

import os

import torch
from torch.utils.data import Dataset
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.strategies import FSDPStrategy


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"x": loss}

    def predict_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"y": loss}

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.trainer.model.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]


def run_test():
    # fake data
    train_data = torch.utils.data.DataLoader(RandomDataset(32, 64))
    val_data = torch.utils.data.DataLoader(RandomDataset(32, 64))
    test_data = torch.utils.data.DataLoader(RandomDataset(32, 64))

    # model
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        max_epochs=1,
        strategy=FSDPStrategy(
            cpu_offload=True,
            use_orig_params=True,
        )
    )
    trainer.fit(model, train_data, val_data)

    # Try to load with hparams
    ckpt_path = trainer.checkpoint_callback.best_model_path
    ckpt = torch.load(ckpt_path)
    print(ckpt)
    model = BoringModel.load_from_checkpoint(ckpt_path)


if __name__ == '__main__':
    run_test()

Error messages and logs

RuntimeError: Error(s) in loading state_dict for BoringModel:
        size mismatch for layer.weight: copying a param with shape torch.Size([33]) from checkpoint, the shape in current model is torch.Size([2, 32]).
        size mismatch for layer.bias: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([2]).
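
To make the mismatch easier to see, the shapes stored in the checkpoint can be dumped directly (a minimal sketch; the checkpoint path below is illustrative, in practice use trainer.checkpoint_callback.best_model_path):

import torch

ckpt = torch.load("lightning_logs/version_0/checkpoints/epoch=0-step=1.ckpt", map_location="cpu")
for name, tensor in ckpt["state_dict"].items():
    # On my setup this prints the flattened, sharded shapes, e.g. layer.weight -> (33,), layer.bias -> (0,)
    print(name, tuple(tensor.shape))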

Environment

Current environment
#- PyTorch Lightning Version: 2.0
#- PyTorch Version: 2.0
#- Python version: 3.10
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration: 2 A40

More info

No response

cc @awaelchli @carmocca

@nlpTRIZ nlpTRIZ added bug Something isn't working needs triage Waiting to be triaged by maintainers labels May 26, 2023
@leng-yue
Contributor

leng-yue commented May 28, 2023

I think saving the full state dict only on rank 0 is the proper behavior (though Fabric provides an option to switch between a full and a sharded state dict). This issue is also mentioned in #16815; however, I can't reproduce it now.
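
For reference, the rank-0-only gather looks roughly like this with raw PyTorch FSDP (a sketch assuming fsdp_model is an already-wrapped module; this is not the actual Lightning implementation):

import torch
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel,
    StateDictType,
)

# Unshard all parameters; with rank0_only=True only rank 0 receives the full
# tensors (offloaded to CPU to avoid OOM), all other ranks get an empty dict.
with FullyShardedDataParallel.state_dict_type(
    fsdp_model,
    StateDictType.FULL_STATE_DICT,
    FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
):
    state_dict = fsdp_model.state_dict()

if torch.distributed.get_rank() == 0:
    torch.save(state_dict, "full_model.pt")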

@nlpTRIZ
Author

nlpTRIZ commented May 28, 2023

The problem I encounter is that only the first shard is saved, which makes the checkpoint unusable. Is the issue solved on your side? (Is that what you mean by "can't reproduce it"?)

@leng-yue
Contributor

I tested your code and can't reproduce the error. It seems like the weights are stored properly.

@nlpTRIZ
Author

nlpTRIZ commented May 28, 2023

Which version of Lightning are you using?

@leng-yue
Contributor

leng-yue commented May 28, 2023

master, tried on both 2x3090 and 8xV100

@nlpTRIZ
Author

nlpTRIZ commented May 28, 2023

Considering the implementation of lightning_module_state_dict() in FSDPStrategy, this is the expected behavior; changing rank0_only to False made it work the way I wanted.

from typing import Any, Dict

from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel,
    StateDictType,
)
from pytorch_lightning.strategies import FSDPStrategy
# _strip_prefix_from_state_dict is a private Lightning helper; its import path
# depends on the installed version.


class CustomFSDP(FSDPStrategy):
    def lightning_module_state_dict(self) -> Dict[str, Any]:
        """Gathers the full state dict by unsharding all the parameters.

        Unlike the default implementation, rank0_only=False makes every rank receive the full
        state dict (offloaded to CPU when world_size > 1) instead of only rank 0.
        """
        assert self.model is not None

        with FullyShardedDataParallel.state_dict_type(
            module=self.model,
            state_dict_type=StateDictType.FULL_STATE_DICT,
            state_dict_config=FullStateDictConfig(offload_to_cpu=(self.world_size > 1), rank0_only=False),
        ):
            state_dict = self.model.state_dict()
            return _strip_prefix_from_state_dict(state_dict, prefix="_forward_module.")
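
Usage is unchanged, just swap in the subclass (a sketch with the same flags as in my repro above); with rank0_only=False every rank materializes the full state dict on CPU, so the checkpoint written by rank 0 contains the unsharded parameters:

trainer = Trainer(
    accelerator="gpu",
    devices=2,
    max_epochs=1,
    strategy=CustomFSDP(cpu_offload=True, use_orig_params=True),
)
trainer.fit(model, train_data, val_data)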

@leng-yue
Contributor

Why would you like to move the full state dict to all ranks? It's not a sharded state dict.

@nlpTRIZ
Author

nlpTRIZ commented May 29, 2023

In this particular case, yes. But with my config (I tried several releases plus master), the default FSDP implementation only saved N/k parameters (N being the number of model parameters and k the number of shards), i.e. only shard 0, which is not usable. Do you get several checkpoints, one for each shard, with my first code?

@leng-yue
Copy link
Contributor

So, in your case, it's not recommended to use state_dict_type=StateDictType.FULL_STATE_DICT. It's unclear if there is a proper way to handle the sharded dictionary in Lightning (although it is supported in Fabric).
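
For completeness, with raw PyTorch FSDP a sharded state dict would be gathered roughly like this (a sketch assuming fsdp_model is an already-wrapped module; as mentioned, Fabric exposes this kind of sharded checkpointing):

from torch.distributed.fsdp import (
    FullyShardedDataParallel,
    ShardedStateDictConfig,
    StateDictType,
)

# Each rank keeps only its own shard of every parameter, so no single rank
# has to materialize the full model.
with FullyShardedDataParallel.state_dict_type(
    fsdp_model,
    StateDictType.SHARDED_STATE_DICT,
    ShardedStateDictConfig(offload_to_cpu=True),
):
    sharded_state = fsdp_model.state_dict()

# Persisting and re-loading such a state dict is typically handled per rank via
# the torch.distributed.checkpoint utilities rather than plain torch.save.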

@awaelchli awaelchli added strategy: fsdp Fully Sharded Data Parallel and removed needs triage Waiting to be triaged by maintainers labels Jun 28, 2023