
Can't reload from checkpoint when using SWA #11665

Closed
ma-batita opened this issue Jan 31, 2022 · 18 comments · Fixed by #9938
Assignees: rohitgr7
Labels: bug (Something isn't working), callback: swa, priority: 1 (Medium priority task)

Comments


ma-batita commented Jan 31, 2022

🐛 Bug

My model worked just fine until I tried some optimisation using SWA.

from pytorch_lightning.callbacks import StochasticWeightAveraging

weighting = StochasticWeightAveraging()

The error itself is not easy to understand:

KeyError                                  Traceback (most recent call last)
<ipython-input-20-2d36fa4eaad0> in <module>()
     16 
     17 
---> 18 trainer.fit(module, data_module, ckpt_path="./checkpoints/best-checkpoint.ckpt")
     19 
     20 wandb.finish()

7 frames
/usr/local/lib/python3.7/dist-packages/torch/optim/lr_scheduler.py in load_state_dict(self, state_dict)
    233         """
    234 
--> 235         lr_lambdas = state_dict.pop('lr_lambdas')
    236         self.__dict__.update(state_dict)
    237         # Restore state_dict keys in order to prevent side effects

KeyError: 'lr_lambdas'

To Reproduce

https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/pl_examples/bug_report/bug_report_model.ipynb

Expected behavior

Resume training from a checkpoint with the SWA callback enabled.

Environment

  • CUDA:
    • GPU:
      • Tesla V100-SXM2-16GB
    • available: True
    • version: 11.1
  • Packages:
    • numpy: 1.19.5
    • pyTorch_debug: False
    • pyTorch_version: 1.10.0+cu111
    • pytorch-lightning: 1.5.9
    • tqdm: 4.62.3
  • System:
    • OS: Linux
    • architecture:
      • 64bit
    • processor: x86_64
    • python: 3.7.12
    • version: #1 SMP Tue Dec 7 09:58:10 PST 2021

cc @tchaton @rohitgr7 @akihironitta @carmocca

ma-batita added the bug (Something isn't working) label Jan 31, 2022
tchaton added the priority: 0 High priority task and priority: 1 Medium priority task labels Jan 31, 2022

myxik commented Feb 1, 2022

Hi! Can I take this issue?


rohitgr7 commented Feb 1, 2022

Hey @BttMA, can you update the reproducible Colab link? Currently it points to the one in the repo, which doesn't have any of your updated code.



rohitgr7 commented Feb 1, 2022

Hey @BttMA!
Can you share an actual failing script?

I tried updating your example with:

import os

from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import StochasticWeightAveraging

# BoringModel and RandomDataset are the helpers from the bug_report_model notebook linked above.

def run(max_epochs, ckpt_path=None):
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=max_epochs,
        enable_model_summary=False,
        callbacks=[StochasticWeightAveraging()],  # <-- the SWA callback
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data, ckpt_path=ckpt_path)
    return trainer

trainer = run(max_epochs=5, ckpt_path=None)
trainer.save_checkpoint('best_checkpoint.ckpt')
# resume from the checkpoint written automatically during the first run
ckpt_path = 'lightning_logs/version_0/checkpoints/epoch=4-step=4.ckpt'
trainer = run(max_epochs=20, ckpt_path=ckpt_path)

and it worked fine... so I guess I am unable to reproduce your issue.

ma-batita (Author) commented:

Sorry, I couldn't reproduce the bug with the boring model either. Is there any other way to share it?


rohitgr7 commented Feb 1, 2022

Share the notebook/script that is failing. You can attach it here too, in case someone else wants to look at it.


ma-batita commented Feb 1, 2022

It has personal/confidential data :/ so I can't share it with everybody.
Sorry, is there any other way to share it with you?


rohitgr7 commented Feb 1, 2022

Maybe you can mimic the data. For starters, return random tensors of the same shape as your original data and use a small model.
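
For instance, a minimal sketch of such a "mimic" dataset (the shapes and field names below are placeholders, not your actual data):

# A stand-in Dataset that returns random tensors with the same structure and
# shapes as the real samples, so the failure can be shared without the
# confidential data itself.
import torch
from torch.utils.data import Dataset

class RandomMimicDataset(Dataset):
    def __init__(self, num_samples=64, num_features=32, num_classes=4):
        self.num_samples = num_samples
        self.num_features = num_features
        self.num_classes = num_classes

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Same shapes as the original samples, random content
        features = torch.randn(self.num_features)
        label = torch.randint(0, self.num_classes, ()).long()
        return features, label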

ma-batita (Author) commented:

Hi @rohitgr7 :)
You can see the failure with this randomly generated data :) I hope you manage to solve it soon :)

https://colab.research.google.com/drive/1JHaHvQ5PhfaYil0HnIkEOcF1MoindYcK?usp=sharing#scrollTo=sNm4IkAefdkL&uniqifier=1

PS: maybe it has something to do with "UserWarning: SWA is currently only supported every epoch." or "Swapping scheduler LambdaLR for SWALR"? But those are just warnings and shouldn't cause loading to fail!


rohitgr7 commented Feb 2, 2022

Yes! That might be the case. We need to save and load the states for this callback to enable proper resuming.

ma-batita (Author) commented:

Which callback? Sorry, I didn't get you.
As you can see in the dummy code, when I get the "UserWarning: SWA is currently only supported every epoch." warning, the program skips it right away and continues to the next epoch. Also, in my real code I get the same warning only after 10 to 15 epochs! It doesn't make any sense.

Maybe SWA does not support many epochs? Even the doc is not very clear about how epochs interact with SWA. We have to dig deeper into this!


rohitgr7 commented Feb 2, 2022

I am talking about StochasticWeightAveraging.

The warning isn't very clear and should be improved. I only worked out what it means by looking at the code.

UserWarning: SWA is currently only supported every epoch.

Ideally a scheduler configured with interval='step' or frequency > 1 should work as configured, but SWA doesn't support that yet, which is what triggers the warning. That configuration is what you have done here:

return dict(lr_scheduler=dict(scheduler=scheduler, interval='step'),
            optimizer=optimizer)
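
For context, a minimal sketch of a configure_optimizers returning this kind of step-interval configuration (the module and hyperparameters below are placeholders, not your actual model):

import torch
import pytorch_lightning as pl

class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        # LambdaLR stepped after every optimizer step because of interval='step'
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 0.95 ** step)
        return dict(lr_scheduler=dict(scheduler=scheduler, interval='step'),
                    optimizer=optimizer)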

Also, in my real code I get the same warning only after 10 to 15 epochs! It doesn't make any sense.

Check out the default parameters: by default, SWA starts at epoch = 0.8 * max_epochs.
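
For illustration, the start point is controlled by swa_epoch_start (values below are just examples, using the callback as it exists in the Lightning version from this thread):

from pytorch_lightning.callbacks import StochasticWeightAveraging

# swa_epoch_start accepts a float (fraction of max_epochs) or an int (absolute epoch index);
# the default is 0.8, i.e. SWA only kicks in for the last 20% of training.
swa_late = StochasticWeightAveraging(swa_epoch_start=0.8)   # default: starts at 0.8 * max_epochs
swa_half = StochasticWeightAveraging(swa_epoch_start=0.5)   # starts halfway through training
swa_at_10 = StochasticWeightAveraging(swa_epoch_start=10)   # starts at epoch 10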

ma-batita (Author) commented:

Check out the default parameters: by default, SWA starts at epoch = 0.8 * max_epochs.

Now I see! For example, if I have 100 epochs then the SWA callback will be activated at the 40th epoch (since 40 = 0.8 * 100).

In my case, the SWA callback is just skipped because of the LambdaLR scheduler. It is never executed, and when I try to load a checkpoint things get messy and I get that error? Correct me if I am wrong, please.


rohitgr7 commented Feb 2, 2022

40th epoch (since 40=0.8*100)

80th epoch.

In your example, during the first run it switched to SWALR at the 40th epoch and saved the checkpoint at the 50th epoch with the SWALR state_dict. But when you reloaded the checkpoint, the trainer restored the scheduler states with LambdaLR configured, so effectively LambdaLR tried to load the state_dict of SWALR, which is what causes this error.
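
To see the mechanism in isolation (a standalone sketch, not Lightning code), loading a SWALR state_dict into a LambdaLR reproduces exactly the KeyError from the traceback above:

import torch
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.swa_utils import SWALR

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Scheduler that was active when the checkpoint was saved (after the SWA swap)
swa_scheduler = SWALR(optimizer, swa_lr=0.05)
# Scheduler that configure_optimizers sets up again when resuming
lambda_scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.95 ** epoch)

# SWALR's state_dict has no 'lr_lambdas' entry, so this raises KeyError: 'lr_lambdas'
lambda_scheduler.load_state_dict(swa_scheduler.state_dict())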

ma-batita (Author) commented:

80th epoch.

OH!! sure yes yes!!

In your example, during the first run it switched to SWALR at the 40th epoch and saved the checkpoint at the 50th epoch with the SWALR state_dict. But when you reloaded the checkpoint, the trainer restored the scheduler states with LambdaLR configured, so effectively LambdaLR tried to load the state_dict of SWALR, which is what causes this error.

Now it makes sense :) Thanks a lot!

Can you suggest anything for me to fix this, please?
Should I change the scheduler in my LightningModule from LambdaLR to SWALR? Something like this:

swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05)


rohitgr7 commented Feb 2, 2022

Should I change the scheduler in my LightningModule from LambdaLR to SWALR? Something like this:

I'm not sure that will work. I'm not familiar with every detail of SWA, but I don't think replacing the scheduler is all that's required to perform SWA; there's a lot more happening inside the callback. For the fix, I think we need to create states for this callback that can be stored and reloaded from the checkpoint while resuming training. We need to investigate what is required to make this work.
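
Roughly the direction (just a sketch of the idea, not the actual fix; it assumes the generic Callback state hooks that newer Lightning versions expose):

import pytorch_lightning as pl

class StatefulSWASketch(pl.Callback):
    """Sketch: persist whatever SWA needs (averaged weights, SWALR state, ...)."""

    def __init__(self):
        self._average_model_state = None
        self._swa_scheduler_state = None

    def state_dict(self):
        # Whatever is returned here gets written into the Trainer checkpoint.
        return {"average_model": self._average_model_state,
                "swa_scheduler": self._swa_scheduler_state}

    def load_state_dict(self, state_dict):
        # Called when resuming from a checkpoint, so SWA can pick up where it left off.
        self._average_model_state = state_dict.get("average_model")
        self._swa_scheduler_state = state_dict.get("swa_scheduler")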

ma-batita (Author) commented:

For the fix, I think we need to create states for this callback that can be stored and reloaded from the checkpoint while resuming training.

Actually I was going to suggest that, but I don't know what held me back 😅
I will keep the issue open for further investigation (it would be helpful if you could mention other members).

thanks a lot!


carmocca commented Feb 4, 2022

For the fix, I think we need to create states for this callback that can be stored and reloaded from the checkpoint while resuming training.

This is correct. Saving and loading is not implemented.

Should I change the scheduler in my LightningModule from LambdaLR to SWALR?

This is done by the callback automatically.

rohitgr7 self-assigned this Feb 4, 2022
Borda removed the priority: 0 High priority task label Aug 8, 2022
rohitgr7 linked a pull request Aug 8, 2022 that will close this issue