
Errors when trying to load a VariationalBNN #15

Open
francescofolino opened this issue May 13, 2022 · 5 comments

Comments

@francescofolino

francescofolino commented May 13, 2022

Hi all,
I'm new to TyXe, and I'm experiencing an issue when I try to load a previously trained model from disk.

More precisely, the error returned is the following:

raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for VariationalBNN: Unexpected key(s) in state_dict: net_guide.rnn.weight_ih_l0.loc_unconstrained
etc.
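To see exactly which keys disagree before calling `load_state_dict`, you can compare the checkpoint's keys with the model's own. The sketch below uses plain PyTorch; `diff_state_dict_keys` is a hypothetical helper (not a TyXe or Pyro function), and the extra `net_guide.weight_loc` key is made up to mimic the error above.

```python
from typing import List, Tuple

import torch
from torch import nn


def diff_state_dict_keys(module: nn.Module, state: dict) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) keys, mirroring what load_state_dict reports."""
    expected = set(module.state_dict().keys())
    provided = set(state.keys())
    missing = sorted(expected - provided)     # in the model, absent from the checkpoint
    unexpected = sorted(provided - expected)  # in the checkpoint, absent from the model
    return missing, unexpected


# Illustration: a checkpoint carrying an extra guide entry the model does not have.
model = nn.Sequential(nn.Linear(3, 1))
state = dict(model.state_dict())
state["net_guide.weight_loc"] = torch.zeros(1, 3)  # hypothetical extra key
missing, unexpected = diff_state_dict_keys(model, state)
print(missing, unexpected)  # [] ['net_guide.weight_loc']
```

With a real checkpoint you would pass the result of `torch.load(...)` on the saved file together with the freshly built model.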

To save the model, I use code like this:

pyro.get_param_store().save(os.path.join(output_dir, "param_store.pt"))
torch.save(model.state_dict(), os.path.join(output_dir, "best_model.pt"))

To load the model (defined as tyxe.VariationalBNN(net, prior, likelihood, guide)), I use:

pyro.clear_param_store()
model.load_state_dict(torch.load(os.path.join(save_model_path, "best_model.pt")))
pyro.get_param_store().load(os.path.join(save_model_path, "param_store.pt"))

What am I doing wrong?

Thank you so much.

@hpplyt
Collaborator

hpplyt commented May 27, 2022

Hi,

Apologies for the slow response! I think this is just due to the variational parameter attributes being initialized lazily by Pyro. If your (deterministic) network has no buffers, i.e. only parameters, you shouldn't need to save/load the state dict at all: the param store should contain everything you need. Otherwise, if you do need to load the state dict, first run a forward pass through your BNN by calling guided_forward with some valid input data, so that the parameter attributes are initialized before loading.
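Both options can be illustrated with a plain-PyTorch analogy. This is a sketch under stated assumptions, not TyXe's implementation: the hypothetical `LazyGuideModel` creates its `net_guide` submodule on the first forward pass to stand in for the lazily initialized variational parameters, the `BatchNorm1d` layer is there only because it carries buffers (running statistics), and `has_buffers` is a hypothetical helper for the first check.

```python
import torch
from torch import nn


class LazyGuideModel(nn.Module):
    """Hypothetical stand-in for a BNN whose guide parameters only appear
    on the first forward pass (an analogy, not TyXe code)."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(3, 4), nn.BatchNorm1d(4))
        self.net_guide = None  # not created until data has been seen

    def forward(self, x):
        if self.net_guide is None:
            # Lazily register extra parameters, as happens with the guide.
            self.net_guide = nn.Linear(3, 4)
        return self.net(x) + self.net_guide(x)


def has_buffers(net: nn.Module) -> bool:
    """True if the net holds buffers (e.g. BatchNorm running stats),
    i.e. state that is not a parameter and needs the state dict."""
    return any(True for _ in net.buffers())


x = torch.randn(8, 3)
trained = LazyGuideModel()
trained(x)  # the guide now exists, so its keys end up in the state dict
state = trained.state_dict()

fresh = LazyGuideModel()
try:
    fresh.load_state_dict(state)  # net_guide.* keys have no home yet
    failed_before_forward = False
except RuntimeError as err:
    failed_before_forward = "Unexpected key(s)" in str(err)

fresh(x)  # a forward pass on valid input initializes the lazy attributes
result = fresh.load_state_dict(state)  # now every key matches
print(has_buffers(fresh.net), failed_before_forward, result.unexpected_keys)  # True True []
```

In the real model, the forward call would be the BNN's guided forward pass on a batch of valid inputs rather than `fresh(x)`.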

Let me know if neither option resolves the error; in that case I'd need to take a closer look at what's going on :)

@Cam-B04
Contributor

Cam-B04 commented Oct 19, 2022

Hi,

I have encountered the same error and I might have found a solution. You need to load the state dict through the .net attribute of your model:

pyro.clear_param_store()
model.net.load_state_dict(torch.load(os.path.join(save_model_path, "best_model.pt")))
pyro.get_param_store().load(os.path.join(save_model_path, "param_store.pt"))

Hope this helps !

@francescofolino
Author

francescofolino commented Oct 19, 2022 via email

@Cam-B04
Copy link
Contributor

Cam-B04 commented Oct 19, 2022

I forgot to mention in my answer that you also need to save it the same way:
torch.save(model.net.state_dict(), os.path.join(output_dir, "best_model.pt"))
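The same kind of analogy shows why saving and loading through `.net` avoids the error: the deterministic network's keys exist as soon as the model is constructed, whereas the guide's do not. `Wrapper` is a hypothetical stand-in for `VariationalBNN` (not TyXe code), and an in-memory buffer replaces the checkpoint file.

```python
import io

import torch
from torch import nn


class Wrapper(nn.Module):
    """Hypothetical stand-in for VariationalBNN: `net` is built eagerly,
    `net_guide` only once data has been seen."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(3, 4), nn.BatchNorm1d(4))
        self.net_guide = None

    def forward(self, x):
        if self.net_guide is None:
            self.net_guide = nn.Linear(3, 4)
        return self.net(x) + self.net_guide(x)


trained = Wrapper()
trained(torch.randn(8, 3))

# Save only the deterministic network's state (parameters and buffers).
checkpoint = io.BytesIO()
torch.save(trained.net.state_dict(), checkpoint)
checkpoint.seek(0)

# A freshly constructed wrapper has no guide yet, but its `net` already has
# exactly the saved keys, so strict loading succeeds without a forward pass.
fresh = Wrapper()
result = fresh.net.load_state_dict(torch.load(checkpoint))
print(result.missing_keys, result.unexpected_keys)  # [] []
```

This restores only the deterministic network's own state; in the thread, the guide's variational parameters are restored separately with `pyro.get_param_store().load(...)`.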

@freakontrol

Hi, I had the same issue and the solution of @Cam-B04 worked correctly, thank you.


4 participants