
Allow training to be resumed with a different learning rate (or other overridden hyperparameters) #12118

Closed
rubvber opened this issue Feb 25, 2022 Discussed in #6494 · 3 comments
Labels
checkpointing (Related to checkpointing) · question (Further information is requested)

Comments

rubvber commented Feb 25, 2022

Discussed in #6494

Originally posted by MightyChaos March 12, 2021
I want to resume training from a checkpoint, but with a different learning rate. How can I achieve that? I don't really care about the training state and don't mind starting a fresh training run, as long as the weights are properly restored.

Right now I'm passing resume_from_checkpoint=ckpt_file when creating the trainer, but this automatically restores the old learning rate.

I also tried removing resume_from_checkpoint=ckpt_file and doing

net_learner.load_from_checkpoint(cfg.ckpt_path, cfg=cfg)
trainer.fit(net_learner, train_data_loader, val_data_loader)

but it seems the weights are discarded and the trainer starts from a random initialization.

Any help will be most appreciated, thanks so much!

Feature

Provide the ability to resume training a model with a different learning rate (or LR scheduler). Currently, within the Lightning framework it only seems possible to resume training from a complete snapshot of a previous state, including not just the model weights and other parameters but also the optimizer state and any hyperparameters set at initialization.
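For reference, resuming currently looks roughly like this (a minimal sketch only; LitModel, the dataloaders, and the checkpoint path are hypothetical placeholders, and newer Lightning versions pass ckpt_path to trainer.fit() instead of resume_from_checkpoint to the Trainer):

import pytorch_lightning as pl

# LitModel and the dataloaders are hypothetical stand-ins for the user's own code.
model = LitModel()
trainer = pl.Trainer(max_epochs=20, resume_from_checkpoint="path/to/last.ckpt")

# Resuming restores the full snapshot: weights, optimizer and LR-scheduler
# state, epoch/step counters, and the hyperparameters saved at initialization.
trainer.fit(model, train_dataloader, val_dataloader)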

In parallel, it is possible to load a model from a saved checkpoint, and override any of its hyperparameters for testing or deployment purposes.

However, a combination of the two does not seem to be possible (but I'd love to be corrected on this). As indicated by the post quoted above, if a model is loaded using .load_from_checkpoint(), and then passed to trainer.fit(), training does not seem to resume with the saved model weights but with a new initialization. I'm not sure if this is a bug or just the lack of a feature.

Motivation

It is quite common (in my experience), while experimenting with different learning rates (or schedules) for a model, to want to halt training and resume it with a new learning rate. E.g. maybe the LR scheduler decreased the LR but you want to see what happens if you increase it again, or perhaps continue with a cyclic learning rate. Or maybe you'd like to reduce the learning rate because the loss has plateaued and you didn't account for this in your LR schedule(r).

More broadly, the model may include hyperparameters that you may want to tweak during a training run, for whatever reason. Maybe you think the model could continue to improve by changing a setting, without having to start from scratch. Again, this is currently possible when using .load_from_checkpoint(), but doesn't seem to be an option when resuming training within the Lightning ecosystem.

Currently, the only workaround for this issue seems to be to load the model using .load_from_checkpoint() (overriding any hyperparameters as desired) and then resume training in a bespoke training script (with an optimizer whose learning rate you can control). But again, I'd love to hear if there's a better way within Lightning - if not then I'd like to propose it as a new feature.
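Sketched out, that workaround would look something like this (LitModel, the checkpoint path, and train_dataloader are hypothetical placeholders for the user's own code, and it assumes training_step returns the loss tensor directly):

import torch

# Restore the weights (and any saved hyperparameters) from the checkpoint.
model = LitModel.load_from_checkpoint("path/to/last.ckpt")
model.train()

# Bespoke training loop with an optimizer whose learning rate we fully control.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
for epoch in range(5):
    for batch_idx, batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        loss = model.training_step(batch, batch_idx)  # assumes a plain loss tensor
        loss.backward()
        optimizer.step()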

cc @awaelchli @ananthsub @ninginthecloud @rohitgr7

awaelchli (Member) commented:

However, a combination of the two does not seem to be possible (but I'd love to be corrected on this). As indicated by the post quoted above, if a model is loaded using .load_from_checkpoint(), and then passed to trainer.fit(), training does not seem to resume with the saved model weights but with a new initialization. I'm not sure if this is a bug or just the lack of a feature.

That's because of incorrect usage:

net_learner.load_from_checkpoint(cfg.ckpt_path, cfg=cfg)
trainer.fit(net_learner, train_data_loader, val_data_loader)

It should be:

net_learner = YourModelClass.load_from_checkpoint(cfg.ckpt_path, cfg=cfg)
trainer.fit(net_learner, train_data_loader, val_data_loader)

Emphasis here is on YourModelClass.

More broadly, the model may include hyperparameters that you may want to tweak during a training run, for whatever reason. Maybe you think the model could continue to improve by changing a setting, without having to start from scratch. Again, this is currently possible when using .load_from_checkpoint(), but doesn't seem to be an option when resuming training within the Lightning ecosystem

A related feature request exists in #5339

You can change the learning rate by making it a hyperparameter of the LightningModule and then setting it when you load the checkpoint:

net_learner = YourModelClass.load_from_checkpoint(cfg.ckpt_path, cfg=cfg, learning_rate=0.25)  # <--- here

(make sure to use save_hyperparameters())
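A minimal sketch of that pattern (the class and its fields are illustrative, not taken from this issue): save_hyperparameters() writes learning_rate into the checkpoint, configure_optimizers() reads it back through self.hparams, so an override passed to load_from_checkpoint() takes effect.

import torch
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def __init__(self, learning_rate=1e-3):
        super().__init__()
        # Stores learning_rate in the checkpoint and exposes it as
        # self.hparams.learning_rate.
        self.save_hyperparameters()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        # Uses the (possibly overridden) hyperparameter value.
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

# Weights come from the checkpoint; learning_rate replaces the saved value.
model = LitModel.load_from_checkpoint("path/to/last.ckpt", learning_rate=0.25)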

awaelchli added the question (Further information is requested) label on Feb 26, 2022
rubvber (Author) commented Feb 27, 2022

Thanks for your response! I now see that the post I quoted (not mine) does indeed use the load_from_checkpoint() method incorrectly. My own usage conforms to your example but seemed to show the same behavior, as the loss immediately blew up after resuming training. However, I must admit I did not specifically check whether the weights were being overwritten - I just put 2 and 2 together with my own experience and that of OP (MightyChaos). That's on me - I should have verified that properly.

Now that I try it again it actually seems to work as expected: weights are loaded correctly and this time the loss doesn't blow up. I'm a bit confused as to what happened previously then. Maybe the new learning rate I set last time was just so big that it immediately diverged, or maybe it was just an unrelated bug or user error on my part. Anyway, sorry to waste your time and thanks again for taking the time to respond!

awaelchli (Member) commented:

Anyway, sorry to waste your time and thanks again for taking the time to respond!

No, not at all, don't worry about it. Happy to help.
