Checkpoints cannot be loaded in non-pl env #2653

Closed
s-rog opened this issue Jul 21, 2020 · 9 comments · Fixed by #3287
Labels
bug (Something isn't working), help wanted (Open to be worked on)

Comments

@s-rog
Contributor

s-rog commented Jul 21, 2020

## 🚀 Feature
Add an option to save only state_dict for ModelCheckpoint callbacks

## 🐛 Bug

PL checkpoints cannot be loaded in non-pl envs

Motivation

To be able to move trained models and weights into pytorch only environments

Additional context

Currently, calling torch.load() on a PL-generated checkpoint in an environment without PL raises a pickling error. For my current use case, I have to load the checkpoints in my training environment and save them again with only the state_dict for the weights.

See reply below for more info
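
A minimal sketch of that workaround (run inside the PL/training environment; the checkpoint path and output filename below are placeholders):

import torch

# Load the PL checkpoint in the PL environment and re-save only the
# state_dict so it can be read later without pytorch_lightning installed.
ckpt = torch.load("lightning_logs/version_0/checkpoints/epoch=9.ckpt", map_location="cpu")
torch.save(ckpt["state_dict"], "weights_only.pth")

# In the PyTorch-only environment this then works:
# state_dict = torch.load("weights_only.pth", map_location="cpu")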

s-rog added the feature (Is an improvement or enhancement) and help wanted (Open to be worked on) labels on Jul 21, 2020
@rohitgr7
Contributor

You can use the save_weights_only parameter in ModelCheckpoint to save only the weights. It will still save epoch, global_step and the PL version, but that won't be a problem there, I guess. Also, can you show the pickling error you are getting?
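
For reference, a minimal sketch of that suggestion (how the callback is passed to Trainer differs between PL versions, and the model/dataloader setup is omitted):

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# Save only the model weights (plus epoch/global_step/pl_version metadata),
# skipping optimizer, lr scheduler and callback states.
checkpoint_cb = ModelCheckpoint(save_weights_only=True)

trainer = pl.Trainer(callbacks=[checkpoint_cb], max_epochs=10)
# trainer.fit(model)  # model is a LightningModule defined elsewhere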

@s-rog
Contributor Author

s-rog commented Jul 22, 2020

I am using save_weights_only and that still causes a pickling error about the pytorch_lightning module not being found (don't have the exact error atm)

@rohitgr7
Contributor

Can you load that checkpoint manually in a PL env and check what keys the file has?

@s-rog
Contributor Author

s-rog commented Aug 27, 2020

Error in non-pl env

ModuleNotFoundError                       Traceback (most recent call last)
<ipython-input-10-dbc5018f5317> in <module>
----> 1 pretrained_dict = torch.load('../input/weights/test.ckpt', map_location=torch.device('cpu'))

/opt/conda/lib/python3.7/site-packages/torch/serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
    591                     return torch.jit.load(f)
    592                 return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
--> 593         return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
    594 
    595 

/opt/conda/lib/python3.7/site-packages/torch/serialization.py in _legacy_load(f, map_location, pickle_module, **pickle_load_args)
    771     unpickler = pickle_module.Unpickler(f, **pickle_load_args)
    772     unpickler.persistent_load = persistent_load
--> 773     result = unpickler.load()
    774 
    775     deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)

ModuleNotFoundError: No module named 'pytorch_lightning'

keys in pl env
dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'hparams_name', 'hyper_parameters'])

@rohitgr7 sorry about the late reply, completely forgot about this issue

Edit:
found the issue, I'll look into a fix

# print the type of every top-level value in the checkpoint
for k, v in pretrained_dict.items():
    print(type(k), type(v))

<class 'str'> <class 'int'>
<class 'str'> <class 'int'>
<class 'str'> <class 'str'>
<class 'str'> <class 'collections.OrderedDict'>
<class 'str'> <class 'str'>
<class 'str'> <class 'pytorch_lightning.utilities.parsing.AttributeDict'>

Edit 2:
I'll submit a PR after refactor week!

s-rog changed the title from "Option to save only 'state_dict' for checkpoints" to "Checkpoints cannot be loaded in non-pl env" on Aug 27, 2020
Borda added the bug (Something isn't working) label and removed the feature (Is an improvement or enhancement) label on Aug 27, 2020
@s-rog
Contributor Author

s-rog commented Aug 31, 2020

I got around to testing and can now load checkpoints in non-PL envs. The only change needed was to cast hyper_parameters to a dict in dump_checkpoint in pytorch_lightning/trainer/training_io.py:

- checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
+ checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams)

Thoughts?
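
With hyper_parameters stored as a plain dict, loading in a PyTorch-only environment should work along these lines (a sketch; "test.ckpt" is a placeholder path, and the key-prefix handling depends on how the LightningModule wraps its submodules):

import torch

# No pytorch_lightning import is needed once the checkpoint contains only
# built-in types and an OrderedDict state_dict.
ckpt = torch.load("test.ckpt", map_location="cpu")
print(ckpt.keys())  # epoch, global_step, pytorch-lightning_version, state_dict, ...

state_dict = ckpt["state_dict"]
# If the LightningModule held the network in an attribute (e.g. self.model),
# the keys carry that prefix and need stripping before use in plain PyTorch.
state_dict = {
    (k[len("model."):] if k.startswith("model.") else k): v
    for k, v in state_dict.items()
}
# plain_model.load_state_dict(state_dict)  # plain_model: a torch.nn.Module defined elsewhere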

@rohitgr7
Contributor

Yeah, this looks good to avoid that error, since AttributeDict is a PL thing.

@rohitgr7
Contributor

@s-rog, I tried on master with save_weights_only=True and these are the dict keys I got. No hyperparams.

dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict'])

@s-rog
Contributor Author

s-rog commented Sep 1, 2020

@rohitgr7 Did the model have self.hparams?

If you look at dump_checkpoint(), the weights_only arg only controls whether
callbacks, optimizer_states, lr_schedulers, native_amp_scaling_state and amp_scaling_state are written.

hparams logging is controlled only by if model.hparams: (see the sketch below)
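
A rough sketch of that control flow (not the actual PL source, just the behaviour described above; the real function lives in pytorch_lightning/trainer/training_io.py):

def dump_checkpoint(model, weights_only=False):
    checkpoint = {
        "epoch": ...,
        "global_step": ...,
        "pytorch-lightning_version": ...,
    }

    if not weights_only:
        # only these entries are gated by weights_only
        checkpoint["callbacks"] = ...
        checkpoint["optimizer_states"] = ...
        checkpoint["lr_schedulers"] = ...
        checkpoint["native_amp_scaling_state"] = ...
        checkpoint["amp_scaling_state"] = ...

    # hparams are added whenever the model defines them, regardless of weights_only
    if model.hparams:
        checkpoint["hyper_parameters"] = dict(model.hparams)  # the proposed cast to a plain dict

    checkpoint["state_dict"] = model.state_dict()
    return checkpoint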

@rohitgr7
Copy link
Contributor

rohitgr7 commented Sep 1, 2020

ok, yeah my bad :)
