Proper way to checkpoint model using FSDP #17798
Hi @wj210 I suggest you don't do that, since Lightning can already do it for you.

One more remark:
Hi, I have the same problem. Training works fine and checkpoints are saved, but I can't load them. The checkpoints seem to have the correct size. I just installed lightning from source and use torch 2.1.0.dev20230327. Do you have any suggestions?
My issue is similar to @villmow's. What I meant by torch.save was an alternative I tried after discovering that the automatic checkpointing produces a checkpoint with an empty state dict. I have tried installing from source, but the result is the same.
As of now, has anyone fixed the broken optimizer states problem in PyTorch Lightning?
Hi @villmow @wj210 @leng-yue

```python
import os

import torch
from torch.utils.data import DataLoader, Dataset

from lightning.pytorch import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(1000, 1000)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(1000, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        strategy="fsdp",
        accelerator="cuda",
        devices=2,
        max_epochs=1,
        enable_model_summary=False,
    )
    trainer.fit(model, train_data)

    checkpoint = torch.load(trainer.checkpoint_callback.best_model_path)
    print(list(checkpoint["state_dict"].keys()))
    for param in checkpoint["state_dict"].values():
        print(param.shape)
        print(param.sum())


if __name__ == "__main__":
    run()
```

Can you please modify it and post the resulting code that produces the checkpoint that can't be loaded? I can help if I get a runnable minimal example back. Thanks!
Hi @awaelchli, I tried the code above but it was giving an error, since FSDP requires referencing self.trainer.model. Anyway, I changed `layer` to `model`.
I think this may be related to #17817.
@wj210 Did you follow the instructions to install the latest version of Lightning #17798 (comment) and use the latest version of PyTorch?
Apologies, I was using a separate environment. Yes, it does work in that example. The difference is that the example calls trainer.checkpoint_callback.best_model_path, while I was using my own ModelCheckpoint. How can I save the checkpoint in that case?
This one should be fine since it only has one layer and FSDP won't shard it. You can check this example: https://github.com/Lightning-AI/lightning/blob/e80e467fa4e79edf34819ebcc8c46a3baf0a33af/tests/tests_pytorch/strategies/test_fsdp.py#L491
I had this issue as well - I found that the following seems to work:

```python
from typing import Any, Dict

import lightning.pytorch as pl


class MyModel(pl.LightningModule):
    ...

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Tentative fix for FSDP checkpointing issue."""
        if not checkpoint.get("state_dict", None):
            # Fall back to pulling the state dict from the (FSDP-wrapped) module
            state_dict = self.trainer.model.state_dict()
            checkpoint["state_dict"] = state_dict
        return super().on_save_checkpoint(checkpoint)
```
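For what it's worth, once the checkpoint actually contains a populated `state_dict`, loading it back into a plain (unwrapped) module needs only stock torch calls. A minimal sketch, using an in-memory buffer in place of the `.ckpt` file and a bare `nn.Linear` as a stand-in for the real model:

```python
import io

import torch

# A plain module with the same architecture as the trained model
# (nn.Linear here is a stand-in for the real layers).
model = torch.nn.Linear(1000, 1000)

# Save a checkpoint dict in the same shape Lightning writes
# ({"state_dict": ...}); the buffer stands in for a .ckpt file on disk.
buffer = io.BytesIO()
torch.save({"state_dict": model.state_dict()}, buffer)
buffer.seek(0)
checkpoint = torch.load(buffer)

# Verify the state dict is non-empty before loading it.
assert checkpoint["state_dict"], "state_dict is empty - the FSDP save went wrong"
model.load_state_dict(checkpoint["state_dict"])
print(sorted(checkpoint["state_dict"].keys()))  # → ['bias', 'weight']
```

The non-empty check mirrors the symptom reported in this thread: a checkpoint of the right file size whose `state_dict` key maps to an empty dict.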
Bug description
I am trying to save the state_dict of a fine-tuned T5 model from Hugging Face; however, I was unable to load from the checkpoint and found that the state_dict in the checkpoint was an empty dict.
What version are you seeing the problem on?
v2.0
More info
How can I properly save the model using FSDP? I tried to save it manually using torch.save(self.model.state_dict) and it was empty as well.
However, when I am not using distributed training, the issue does not occur.
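One thing worth double-checking in the snippet above: `torch.save(self.model.state_dict)` passes the bound method itself rather than its result; the call needs parentheses. That is separate from the FSDP sharding issue discussed here, but it produces a similarly useless checkpoint. A quick torch-only sketch of the difference (a small `nn.Linear` stands in for the T5 model):

```python
import torch

model = torch.nn.Linear(8, 8)

# Without parentheses this is a bound method, not the weights:
assert callable(model.state_dict)

# With parentheses we get the actual mapping of parameter names to tensors,
# which is what torch.save should receive:
state = model.state_dict()
assert isinstance(state, dict)
print(sorted(state.keys()))  # → ['bias', 'weight']
```

Under FSDP there is the additional wrinkle that each rank may only hold its own shard, so even a correctly called `state_dict()` can come back incomplete unless the full state dict is gathered first.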
cc @awaelchli @carmocca