
Support training resume and saving best model for SWA #6074

Closed
b02202050 opened this issue Feb 19, 2021 · 4 comments
Labels
feature (Is an improvement or enhancement), help wanted (Open to be worked on)

@b02202050

🚀 Feature

Support stopping and resuming training, and saving the model based on the best validation performance, when using Stochastic Weight Averaging (SWA).

Motivation

If training does not run to completion, the SWA model is not saved to the checkpoint.
For example:

  1. One may want to save the model with the best validation performance before it overfits.
  2. The training may become unstable before reaching the final epoch.

I observe that the saved checkpoint still contains the original model weights even when I use the SWA callback.

Pitch

  1. Use the SWA model to run the validation step instead of the original model.
    1. We may load the SWA weights before validation and restore the original model weights after validation.
  2. Save both the original model weights and the SWA model weights into the checkpoint so training can be resumed.
    1. Since someone loading the model for testing would most likely expect the SWA weights, we can save the SWA weights to checkpoint['state_dict'] and save the original model weights as callback state (see the sketch after this list).
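
A minimal sketch of how both points could be wired together as a Lightning callback. This is illustrative only, not the built-in StochasticWeightAveraging implementation: the hook signatures follow pytorch_lightning.Callback of that era, the `_average_model` attribute is an assumption standing in for whatever holds the averaged weights during training, and the "swa_original_state_dict" checkpoint key is hypothetical.

import copy

from pytorch_lightning.callbacks import Callback


class SWACheckpointSketch(Callback):
    def __init__(self):
        self._average_model = None    # assumed to be kept up to date elsewhere during training
        self._backup_weights = None

    def on_validation_start(self, trainer, pl_module):
        if self._average_model is None:
            return
        # pitch 1: keep the original weights aside and run validation with the SWA weights
        self._backup_weights = copy.deepcopy(pl_module.state_dict())
        pl_module.load_state_dict(self._average_model.state_dict())

    def on_validation_end(self, trainer, pl_module):
        if self._backup_weights is not None:
            # restore the original weights so training continues unaffected
            pl_module.load_state_dict(self._backup_weights)
            self._backup_weights = None

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        if self._average_model is not None:
            # pitch 2: expose the SWA weights as checkpoint['state_dict'] and keep the
            # original weights alongside so training can be resumed from them
            checkpoint["swa_original_state_dict"] = copy.deepcopy(pl_module.state_dict())
            checkpoint["state_dict"] = self._average_model.state_dict()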

Alternatives

Additional context

@Borda @MilesCranmer, I do not have much coding experience, nor am I familiar with PyTorch Lightning. Would you help? Thanks a lot. 😉

@b02202050 added the feature and help wanted labels Feb 19, 2021
@Borda added this to the 1.3 milestone Feb 19, 2021
@edenlightning removed this from the v1.3 milestone Apr 27, 2021
@edenlightning added this to the v1.4 milestone May 9, 2021
@edenlightning modified the milestones: v1.4, v1.5 Jul 6, 2021
@adamreeve
Contributor

adamreeve commented Oct 15, 2021

Hi, I'm also interested in this feature and have had a go at implementing it. I've opened a draft PR at #9938, and would appreciate some feedback from the PyTorch Lightning devs to make sure you're happy with this approach before I tidy it up and add tests and documentation.

I've made using SWA weights for validation optional and defaulted it to False for backwards compatibility. Loading the best model from a checkpoint seems a bit awkward though. With my changes you can use the resume_from_checkpoint argument when creating a Trainer to resume training, but if you use the load_from_checkpoint method of LightningModule then the model still contains the non-SWA parameters, as it doesn't know anything about the SWA callback. I've added a classmethod StochasticWeightAveraging.restore_average_parameters_from_checkpoint that needs to be called separately after creating the module, so any thoughts on how to make the API there nicer would be appreciated. For example, loading the best checkpoint currently looks a bit like this:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging

checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_top_k=3, mode="min")
swa_callback = StochasticWeightAveraging(swa_epoch_start=0.8, swa_validation=True)

trainer = Trainer(..., callbacks=[swa_callback, checkpoint_callback])
trainer.fit(model, datamodule=datamodule)

# load the best checkpoint, then restore the averaged weights into the new module
checkpoint_path = checkpoint_callback.best_model_path
new_model = MyModel.load_from_checkpoint(checkpoint_path=checkpoint_path)
parameters_loaded: bool = StochasticWeightAveraging.restore_average_parameters_from_checkpoint(new_model, checkpoint_path)

One remaining issue is that when resuming training, the SWALR scheduler is recreated with _step_count = 0, so annealing will begin again if it had previously completed. I'm not sure how best to handle that. I tried calling the step method multiple times to move the scheduler forward to where it was previously, but that logs a warning: "Detected call of lr_scheduler.step() before optimizer.step()". I think that warning would be safe to ignore in this scenario, but users might worry that something is wrong if they saw it. I could just save and restore the _step_count state, but that's a private property on a class within PyTorch, so that doesn't seem like a great idea (a rough sketch of that workaround is below).
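
If keeping the private counter turned out to be the least bad option, it could look roughly like this. This is purely illustrative, not what the draft PR does: the helper names are made up, and _step_count is a private attribute of PyTorch LR schedulers, so this relies on an implementation detail.

def swa_scheduler_state(swa_scheduler):
    # state the SWA callback could store under its own key in the checkpoint
    return {"scheduler_step_count": swa_scheduler._step_count}


def restore_swa_scheduler_state(swa_scheduler, state):
    # move the freshly recreated SWALR scheduler back to where it was before the
    # interruption, so annealing does not restart from step 0
    swa_scheduler._step_count = state["scheduler_step_count"]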

I also need to work out how to handle models with batch normalization.

@adamreeve
Contributor

adamreeve commented Oct 18, 2021

I've been looking into how batch normalization currently works when using SWA, to better understand how this should be handled during validation. The mean and variance of the inputs to the BatchNorm layers depend on the weights of the lower layers, so the running estimates computed by the underlying model won't be accurate once the weights are replaced with the averaged weights. Therefore the mean and variance are recomputed with a full pass over the training set at the end of training, after the averaged weights have been swapped in. This is done by increasing the number of epochs by one while setting _skip_backward to True for the last epoch so that the optimization step is skipped, and by modifying the batch norm layers before this final epoch: their mean and variance tensors are reset and momentum is set to None so that they compute a simple cumulative moving average.
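
For reference, plain PyTorch exposes this same recipe as torch.optim.swa_utils.update_bn, which resets the BatchNorm running statistics, sets momentum to None, and makes one forward-only pass over the data. A rough sketch of how it is used outside Lightning, assuming `model` is an nn.Module and `train_loader` a DataLoader (Lightning's callback does the equivalent work through the training loop instead):

from torch.optim.swa_utils import AveragedModel, update_bn

# `model` and `train_loader` are assumed to exist in the surrounding script
swa_model = AveragedModel(model)
# ... training loop, periodically calling swa_model.update_parameters(model) ...
update_bn(train_loader, swa_model)   # full pass to recompute BatchNorm mean/variance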

The validation passes should probably also update the batch norm parameters in order to accurately represent the model's performance. Using the same approach of adding extra non-optimization training epochs before validation seems like it could get quite awkward though. I'm wondering whether it would be simpler to do this without modifying the training loop parameters, by just using the training data fetcher directly. I've updated my draft PR to use this approach, but I'm not that familiar with the PyTorch Lightning code base, so I'm not sure what I might be missing by using the data fetcher directly rather than letting the training loop handle this.

The StochasticWeightAveraging.restore_average_parameters_from_checkpoint method will also need to be updated to take a data module or training data loader so that it can recompute batch norm parameters.

@adamreeve
Contributor

Hi @tchaton, it looks like you implemented SWA in PyTorch Lightning originally, do you have any thoughts on this?

@awaelchli modified the milestones: v1.5, v1.6 Nov 4, 2021
@Borda modified the milestones: 1.6, 1.6.x Mar 21, 2022
@carmocca modified the milestones: pl:1.6.x, pl:1.8 Jul 28, 2022
@carmocca
Contributor

Implemented in #9938
