Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trainer.save_checkpoint() after Trainer.test() not working with FSDP #18971

Closed
awaelchli opened this issue Nov 8, 2023 · 0 comments · Fixed by #18992
Closed

Trainer.save_checkpoint() after Trainer.test() not working with FSDP #18971

awaelchli opened this issue Nov 8, 2023 · 0 comments · Fixed by #18992
Assignees
Labels
bug Something isn't working strategy: fsdp Fully Sharded Data Parallel ver: 2.1.x
Milestone

Comments

@awaelchli
Copy link
Member

awaelchli commented Nov 8, 2023

Bug description

Calling `Trainer.save_checkpoint()` after `Trainer.test()` fails with a `KeyError` in FSDP's `optim_state_dict` mapping (a parameter is missing from `param_to_fqns`), while the same call made right after `Trainer.fit()` works. See the title and the traceback below.

What version are you seeing the problem on?

v2.1

How to reproduce the bug

import lightning as L
import torch
import torch.nn.functional as F
from lightning.pytorch.demos import Transformer, WikiText2
from torch.utils.data import DataLoader, random_split


class LanguageDataModule(L.LightningDataModule):
    """DataModule serving WikiText2 train/val/test splits.

    Args:
        batch_size: Batch size used by all three dataloaders.
    """

    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size
        # Vocabulary size of the WikiText2 dataset (fixed constant).
        self.vocab_size = 33278

    def prepare_data(self):
        # Download the dataset once (called on a single process only).
        WikiText2(download=True)

    def setup(self, stage):
        dataset = WikiText2()

        # Partition into train / val / test subsets.
        total = len(dataset)
        lengths = [total - 4000, 2000, 2000]
        self.train_dataset, self.val_dataset, self.test_dataset = random_split(dataset, lengths)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)


class LanguageModel(L.LightningModule):
    """Transformer language model whose network is built lazily.

    The model is constructed in ``configure_model`` rather than ``__init__``
    so the strategy (e.g. FSDP) can shard parameters as they are created.

    Args:
        vocab_size: Size of the vocabulary for the Transformer.
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size

    def configure_model(self):
        # Deferred construction hook; ``self.model`` does not exist before this.
        self.model = Transformer(vocab_size=self.vocab_size)

    def _compute_loss(self, batch):
        # Shared forward + NLL loss used by all three step hooks.
        inputs, targets = batch
        logits = self.model(inputs, targets)
        return F.nll_loss(logits, targets.view(-1))

    def training_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log("val_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log("test_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def main():
    """Reproduce the FSDP save_checkpoint-after-test failure."""
    L.seed_everything(42)

    dm = LanguageDataModule(batch_size=20)
    model = LanguageModel(dm.vocab_size)

    trainer = L.Trainer(max_epochs=1, strategy="fsdp")
    trainer.fit(model, datamodule=dm)
    # trainer.save_checkpoint("ptl_fsdp.ckpt")  # <-------------  here works
    trainer.test(model, datamodule=dm)

    # trainer.save_checkpoint("ptl_fsdp.ckpt")  # <------------- here doesn't work


# Run the reproduction script when executed directly (not on import).
if __name__ == "__main__":
    main()

Error messages and logs

Traceback (most recent call last):
  File "/teamspace/studios/this_studio/main.py", line 84, in <module>
    main()
  File "/teamspace/studios/this_studio/main.py", line 80, in main
    trainer.save_checkpoint("ptl_fsdp.ckpt")
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1382, in save_checkpoint
    checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 445, in dump_checkpoint
    optimizer_state = trainer.strategy.optimizer_state(optimizer)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/pytorch/strategies/fsdp.py", line 503, in optimizer_state
    state_dict = FSDP.optim_state_dict(self.model, optimizer)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1818, in optim_state_dict
    return FullyShardedDataParallel._optim_state_dict_impl(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1246, in _optim_state_dict_impl
    return _optim_state_dict(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1354, in _optim_state_dict
    all_optim_state_keys, optim_state_key_to_param_key = _map_param_key_to_optim_keys(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1184, in _map_param_key_to_optim_keys
    fqns = param_to_fqns[param]
KeyError: Parameter containing:
tensor([-0.0283, -0.0378, -0.0333,  ...,  0.0511, -0.0303, -0.0215],
       device='cuda:3', requires_grad=True)
Traceback (most recent call last):
  File "/teamspace/studios/this_studio/main.py", line 84, in <module>
    main()
  File "/teamspace/studios/this_studio/main.py", line 80, in main
    trainer.save_checkpoint("ptl_fsdp.ckpt")
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1382, in save_checkpoint
    checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 445, in dump_checkpoint
    optimizer_state = trainer.strategy.optimizer_state(optimizer)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/pytorch/strategies/fsdp.py", line 503, in optimizer_state
    state_dict = FSDP.optim_state_dict(self.model, optimizer)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1818, in optim_state_dict
    return FullyShardedDataParallel._optim_state_dict_impl(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1246, in _optim_state_dict_impl
    return _optim_state_dict(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1354, in _optim_state_dict
    all_optim_state_keys, optim_state_key_to_param_key = _map_param_key_to_optim_keys(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1184, in _map_param_key_to_optim_keys
    fqns = param_to_fqns[param]
KeyError: Parameter containing:
tensor([-0.2867,  0.0361, -0.6431,  ..., -2.3856,  0.9756, -1.5845],
       device='cuda:1', requires_grad=True)
Traceback (most recent call last):
  File "/teamspace/studios/this_studio/main.py", line 84, in <module>
    main()
  File "/teamspace/studios/this_studio/main.py", line 80, in main
    trainer.save_checkpoint("ptl_fsdp.ckpt")
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1382, in save_checkpoint
    checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 445, in dump_checkpoint
    optimizer_state = trainer.strategy.optimizer_state(optimizer)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/pytorch/strategies/fsdp.py", line 503, in optimizer_state
    state_dict = FSDP.optim_state_dict(self.model, optimizer)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1818, in optim_state_dict
    return FullyShardedDataParallel._optim_state_dict_impl(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1246, in _optim_state_dict_impl
    return _optim_state_dict(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1354, in _optim_state_dict
    all_optim_state_keys, optim_state_key_to_param_key = _map_param_key_to_optim_keys(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1184, in _map_param_key_to_optim_keys
    fqns = param_to_fqns[param]
KeyError: Parameter containing:
tensor([-0.0150, -0.0183, -0.0415,  ..., -0.0789, -0.0992,  0.0360],
       device='cuda:2', requires_grad=True)
Traceback (most recent call last):
  File "/teamspace/studios/this_studio/main.py", line 84, in <module>
    main()
  File "/teamspace/studios/this_studio/main.py", line 80, in main
    trainer.save_checkpoint("ptl_fsdp.ckpt")
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1382, in save_checkpoint
    checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 445, in dump_checkpoint
    optimizer_state = trainer.strategy.optimizer_state(optimizer)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/pytorch/strategies/fsdp.py", line 503, in optimizer_state
    state_dict = FSDP.optim_state_dict(self.model, optimizer)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1818, in optim_state_dict
    return FullyShardedDataParallel._optim_state_dict_impl(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1246, in _optim_state_dict_impl
    return _optim_state_dict(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1354, in _optim_state_dict
    all_optim_state_keys, optim_state_key_to_param_key = _map_param_key_to_optim_keys(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1184, in _map_param_key_to_optim_keys
    fqns = param_to_fqns[param]
KeyError: Parameter containing:
tensor([-0.0242,  1.0170, -0.4915,  ..., -0.5053, -0.3954,  0.2465],
       device='cuda:0', requires_grad=True)

Environment

Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0): 2.1.0
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0): 2.1.0
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):

More info

No response

cc @awaelchli @carmocca

@awaelchli awaelchli added bug Something isn't working needs triage Waiting to be triaged by maintainers labels Nov 8, 2023
@awaelchli awaelchli added strategy: fsdp Fully Sharded Data Parallel and removed needs triage Waiting to be triaged by maintainers labels Nov 8, 2023
@awaelchli awaelchli added this to the 2.1.x milestone Nov 8, 2023
@awaelchli awaelchli self-assigned this Nov 8, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working strategy: fsdp Fully Sharded Data Parallel ver: 2.1.x
Projects
None yet
Development

Successfully merging a pull request may close this issue.

1 participant