
Proper way to checkpoint model using FSDP #17798

Closed
wj210 opened this issue Jun 9, 2023 · 12 comments · Fixed by #17819
Labels
question (Further information is requested) · strategy: fsdp (Fully Sharded Data Parallel) · ver: 2.0.x

Comments

wj210 commented Jun 9, 2023

Bug description

I am trying to save the state_dict of a fine-tuned T5 model from Hugging Face. However, I was unable to load from the checkpoint, and on inspection the state_dict inside the checkpoint was an empty dict.

What version are you seeing the problem on?

v2.0

How to reproduce the bug

class T5FineTuner(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)
        self.model = T5ForConditionalGeneration.from_pretrained(hparams.model_name_or_path)


model = T5FineTuner(args)
early_stop_callback = pl.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    verbose=False,
    mode='min')
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath='model_checkpoints/',
    filename=args.model_name_or_path,  # your custom filename
    save_weights_only=True,
    monitor='val_loss',
    mode='min',
    save_top_k=1,
)
callbacks = [LoggingCallback(logger=logger, task=args.task, file_path=output_num_file, args=args), early_stop_callback, checkpoint_callback]

trainer = pl.Trainer(strategy='fsdp', callbacks=callbacks, enable_checkpointing=True)

Error messages and logs

The state_dict was basically empty.

Environment

Current environment: not provided.

More info

How can I properly save the model using FSDP? I tried to save it manually using torch.save(self.model.state_dict()) and the result was empty as well.
However, when I am not using distributed training, the issue does not occur.

cc @awaelchli @carmocca

wj210 added the bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) labels on Jun 9, 2023
awaelchli (Member)

Hi @wj210
You can't just save an FSDP model with a manual torch.save; you would have to add a lot of boilerplate code from PyTorch to get this right.

I suggest you don't do that, since Lightning can already do it for you. With Trainer(enable_checkpointing=True), the trainer will already save checkpoints to the logging directory. Furthermore, you can trigger a manual save yourself using trainer.save_checkpoint(...).
Hope this helps.
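
For reference, a minimal sketch of that manual save (assuming trainer.fit(...) has already run; the file name here is just an example):

# Save a full checkpoint to an explicit path, in addition to whatever
# enable_checkpointing / ModelCheckpoint already write automatically.
trainer.save_checkpoint("manual_checkpoint.ckpt")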

awaelchli added the question (Further information is requested) label and removed the needs triage (Waiting to be triaged by maintainers) and bug (Something isn't working) labels on Jun 10, 2023
awaelchli (Member) commented Jun 10, 2023

One more remark:
For FSDP to work properly (and it is still experimental and incomplete!), you need to install lightning from source.

pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip -U

awaelchli added the strategy: fsdp (Fully Sharded Data Parallel) label on Jun 10, 2023
villmow commented Jun 12, 2023

Hi, I have the same problem. I train a model with Trainer(strategy="fsdp", accelerator="gpu", devices=8, precision="bf16-mixed", ...) and add checkpointing with the ModelCheckpoint callback. I don't save checkpoints manually.

Training works fine and checkpoints are saved, but I can't load them. The checkpoints seem to have the correct size, but the state_dict is empty when I inspect them with torch.load.

I just installed lightning from source and use torch 2.1.0.dev20230327.

Do you have any suggestions?
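
For context, the inspection described above amounts to roughly the following (the checkpoint path is a placeholder):

import torch

ckpt = torch.load("lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt", map_location="cpu")
print(ckpt.keys())           # "state_dict", "optimizer_states", "callbacks", ...
print(ckpt["state_dict"])    # empty dict in the failure case described here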

wj210 (Author) commented Jun 12, 2023

My issue is similar to @villmow's. What I meant by torch.save was an alternative I tried after discovering that the automatic checkpointing produces a checkpoint with an empty state dict. I have tried installing from source, but the result is the same.

leng-yue (Contributor)

One more remark: For FSDP to work properly (and it is still experimental and incomplete!), you need to install lightning from source.

pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip -U

As of now, has anyone fixed the broken optimizer-state problem in PyTorch Lightning?

awaelchli (Member)

Hi @villmow @wj210 @leng-yue
Here is example code that produces a checkpoint that is not empty:

import os

import torch
from torch.utils.data import DataLoader, Dataset

from lightning.pytorch import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(1000, 1000)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(1000, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        strategy="fsdp",
        accelerator="cuda",
        devices=2,
        max_epochs=1,
        enable_model_summary=False,
    )
    trainer.fit(model, train_data)

    checkpoint = torch.load(trainer.checkpoint_callback.best_model_path)
    print(list(checkpoint["state_dict"].keys()))
    for param in checkpoint["state_dict"].values():
        print(param.shape)
        print(param.sum())


if __name__ == "__main__":
    run()

Can you please modify it and post the resulting code that produces the checkpoint that can't be loaded? I can help if I get a runnable minimal example back. Thanks!

wj210 (Author) commented Jun 13, 2023

Hi @awaelchli, I tried the code above, but it gave an error because FSDP requires referencing self.trainer.model:
ValueError: The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the optimizer after setting up the model by referencing `self.trainer.model.parameters()` in the `configure_optimizers()` hook.

Anyway, I changed layer to model and, in configure_optimizers, changed self.layer to self.trainer.model; it still returns an empty dict.
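
Concretely, the modification described above to the earlier BoringModel example looks something like this (a sketch; the rest of the model is unchanged):

    def configure_optimizers(self):
        # Build the optimizer from the FSDP-wrapped module, as the error hint suggests
        return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1)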

leng-yue (Contributor)

I think this may relate to #17817.

awaelchli (Member)

@wj210 Did you follow the instructions to install the latest version of Lightning #17798 (comment) and use the latest version of PyTorch?

wj210 (Author) commented Jun 13, 2023

Apologies, I was using a separate environment. Yes, it does work in that example. The difference is that the example reads trainer.checkpoint_callback.best_model_path, while I pass my own ModelCheckpoint callback. How can I save and locate the checkpoint in that case?
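
For what it's worth, an explicitly passed ModelCheckpoint records the same path, so the saved file can be located like this (a sketch based on the setup from the original report; train_loader is a placeholder):

trainer = pl.Trainer(strategy='fsdp', callbacks=[checkpoint_callback], enable_checkpointing=True)
trainer.fit(model, train_loader)

# After fit, the callback keeps track of where the best checkpoint was written
best_path = checkpoint_callback.best_model_path
state_dict = torch.load(best_path)["state_dict"]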

leng-yue (Contributor)

This one should be fine since it only has one layer and FSDP won't shard it. You can check this example: https://github.com/Lightning-AI/lightning/blob/e80e467fa4e79edf34819ebcc8c46a3baf0a33af/tests/tests_pytorch/strategies/test_fsdp.py#L491
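
For readers not following the link, manually gathering a full state dict from an FSDP-wrapped module with plain PyTorch looks roughly like this (a sketch of the general PyTorch API, not the exact contents of the linked test; fsdp_model is a placeholder for the wrapped module):

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType

# Consolidate the sharded parameters into a single full state dict on rank 0
config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, config):
    full_state = fsdp_model.state_dict()

if torch.distributed.get_rank() == 0:
    torch.save(full_state, "full_model.pt")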

wukevin commented Jul 23, 2023

I had this issue as well. I found that the following seems to work:

from typing import Any, Dict

import lightning.pytorch as pl


class MyModel(pl.LightningModule):

    ...

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Tentative fix for the FSDP checkpointing issue."""
        # If Lightning produced an empty state_dict, fall back to the
        # state dict of the FSDP-wrapped module held by the trainer.
        if not checkpoint.get("state_dict", None):
            checkpoint["state_dict"] = self.trainer.model.state_dict()
        return super().on_save_checkpoint(checkpoint)
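
With that hook in place, the saved checkpoint should contain a populated state_dict and can be inspected the usual way (a sketch; the path is a placeholder):

import torch

ckpt = torch.load("model_checkpoints/best.ckpt", map_location="cpu")
print(len(ckpt["state_dict"]))  # non-zero once the hook has run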
