Model restore fails from stored checkpoint when using Deepspeed #7282

Closed
gurvindersingh opened this issue Apr 29, 2021 · 4 comments · Fixed by #8397
Labels: 3rd party (Related to a 3rd-party) · bug (Something isn't working) · distributed (Generic distributed-related topic) · help wanted (Open to be worked on) · priority: 1 (Medium priority task)

Comments

@gurvindersingh

gurvindersingh commented Apr 29, 2021

🐛 Bug

Trying to restore a checkpoint to resume training, but it fails with the exception below:

RuntimeError: Error(s) in loading state_dict for BoringModel:
        size mismatch for layer.weight: copying a param with shape torch.Size([2, 32]) from checkpoint, the shape in current model is torch.Size([1]).
        size mismatch for layer.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([1]).
    self.lightning_module.load_state_dict(ckpt['state_dict'])
  File "/home/ca5b7a03-2d901b-2d45e5-2d969e-2df8ccc075972b/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1052, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for BoringModel:
        size mismatch for layer.weight: copying a param with shape torch.Size([2, 32]) from checkpoint, the shape in current model is torch.Size([1]).
        size mismatch for layer.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([1]).
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for BoringModel:
        size mismatch for layer.weight: copying a param with shape torch.Size([2, 32]) from checkpoint, the shape in current model is torch.Size([1]).
        size mismatch for layer.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([1]).
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for BoringModel:
        size mismatch for layer.weight: copying a param with shape torch.Size([2, 32]) from checkpoint, the shape in current model is torch.Size([1]).
        size mismatch for layer.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([1]).

To Reproduce

Run the following script with the resume_from_checkpoint argument commented out, then run it again with that line uncommented and you will see the exception.

import os
import torch
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.plugins import DeepSpeedPlugin
from deepspeed.ops.adam import FusedAdam


class RandomDataset(Dataset):

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    checkpoint_callback = ModelCheckpoint(
        dirpath='tests/',
        filename='{epoch:02d}',
    )
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        gpus=-1,
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        precision=16,
        accelerator='ddp',
        max_epochs=2,
        plugins=[DeepSpeedPlugin(cpu_offload=False, stage=3)],
        weights_summary=None,
        callbacks=[checkpoint_callback],
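        # Uncomment on the second run to resume from the stored checkpoint: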
        #resume_from_checkpoint='tests/epoch=01.ckpt',
    )
    trainer.fit(model, train_dataloader=train_data, val_dataloaders=val_data)
    trainer.test(model, test_dataloaders=test_data)


if __name__ == '__main__':
    run()

Expected behavior

Training resumes successfully from the stored checkpoint.

Environment

PyTorch Lightning versions tried: 1.2.10, 1.3.0rc1, and master
PyTorch: 1.7.1
OS: Ubuntu 18.04

@SeanNaren As discussed on Slack ^^

@gurvindersingh added the bug (Something isn't working) and help wanted (Open to be worked on) labels Apr 29, 2021
@SeanNaren self-assigned this Apr 30, 2021
@SeanNaren added the distributed (Generic distributed-related topic) and 3rd party (Related to a 3rd-party) labels Apr 30, 2021
@SeanNaren
Contributor

Hi @gurvindersingh, I have a fix at #7297, but the tests are failing (the repro above now works on multiple GPUs, though). Could you test it out and let me know?

@gurvindersingh
Author

@SeanNaren the model does get loaded fine now without error, but it seems the state is somehow corrupted, as the training loss was high upon resuming from the checkpoint. The loss was the same as if the network had been randomly initialized rather than restored from the trained checkpoint, so there seems to be an issue in correctly storing/loading the weights with DeepSpeed.

When I run my training (from scratch) without DeepSpeed and resume it from the newly stored checkpoint, the training loss resumes from where it was when the checkpoint was saved, so the model loads its state successfully and resumes training without DeepSpeed.
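
One way to confirm that suspicion is to compare the tensors stored in the checkpoint against the parameters of a freshly restored model; if the restore worked, they should match exactly. A minimal sketch, assuming a standard single-file Lightning checkpoint with a state_dict key (the path is the one from the repro script, and the import of BoringModel from the repro script is hypothetical):

import torch

# BoringModel is the class from the repro script above; the module name
# "repro" is hypothetical and depends on how that script is saved.
from repro import BoringModel

ckpt = torch.load("tests/epoch=01.ckpt", map_location="cpu")
model = BoringModel.load_from_checkpoint("tests/epoch=01.ckpt")
restored = model.state_dict()

# Every parameter in the checkpoint should be identical after restore.
for name, saved in ckpt["state_dict"].items():
    if not torch.equal(saved.cpu(), restored[name].cpu()):
        print(f"mismatch in {name}")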

@SeanNaren
Contributor

Thanks @gurvindersingh, this probably comes from the fact that only the model itself is loaded, not the optimizer states. I'll investigate further what's required to restore the optimizer state and LR scheduler state!
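
For context, plain DeepSpeed restores optimizer and LR-scheduler state through its engine-level checkpoint API rather than through the module's state_dict. A rough, illustrative sketch of that API outside Lightning (the tiny model, config, and paths are made up for illustration, and this would normally run under the deepspeed launcher):

import torch
import deepspeed

# A tiny model and optimizer purely to illustrate the engine-level API.
model = torch.nn.Linear(32, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
ds_config = {"train_batch_size": 2}  # minimal illustrative config

model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model, optimizer=optimizer, config=ds_config
)

# save_checkpoint writes module, optimizer, and scheduler state into a directory.
model_engine.save_checkpoint("checkpoints/", tag="step_1000")

# load_checkpoint can bring all of that state back, not just the module weights.
model_engine.load_checkpoint(
    "checkpoints/",
    tag="step_1000",
    load_optimizer_states=True,
    load_lr_scheduler_states=True,
)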

@SeanNaren
Contributor

SeanNaren commented Aug 3, 2021

We've merged a lot of fixes for DeepSpeed that should allow a checkpoint to be restored fully! This required changing the default saving method to rely fully on DeepSpeed (which saves a directory). You can generate a single file for inference by following these instructions: https://pytorch-lightning.readthedocs.io/en/latest/advanced/advanced_gpu.html#deepspeed-zero-stage-3-single-file. Let us know if you run into any issues!
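
For reference, the linked docs cover collating the sharded ZeRO stage 3 checkpoint directory into a single fp32 file for inference. A minimal sketch using DeepSpeed's own conversion utility (the paths are hypothetical, and the exact steps in the docs may differ slightly):

import torch
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

# With the new defaults the checkpoint is a directory of sharded files.
checkpoint_dir = "tests/epoch=01.ckpt"   # hypothetical checkpoint directory
output_file = "tests/epoch=01_fp32.pt"   # hypothetical output path

# Collate the ZeRO-partitioned shards into a single fp32 state dict on disk.
convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file)

# The resulting file holds a plain state dict that can be loaded for inference
# (key prefixes may need adjusting depending on how the module was wrapped).
state_dict = torch.load(output_file, map_location="cpu")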
