Fix saving native AMP scaler state #1777

Merged
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -46,6 +46,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed lr key name in case of param groups in LearningRateLogger ([#1719](https://github.com/PyTorchLightning/pytorch-lightning/pull/1719))

- Fixed saving native AMP scaler state (introduced in [#1561](https://github.com/PyTorchLightning/pytorch-lightning/pull/1561))


## [0.7.5] - 2020-04-27

### Changed
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/training_io.py
@@ -338,8 +338,8 @@ def dump_checkpoint(self):

         checkpoint['state_dict'] = model.state_dict()

-        # restore native amp scaling
-        if self.use_amp and self.use_native_amp and 'native_amp_scaling_state' in checkpoint:
+        # save native amp scaling
+        if self.use_amp and self.use_native_amp:
             checkpoint['native_amp_scaling_state'] = self.scaler.state_dict()

         if hasattr(model, "hparams"):
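For context, the old code reused the restore-path guard (`'native_amp_scaling_state' in checkpoint`) while *building* a fresh checkpoint dict, so the condition was never true and the scaler state was silently dropped. Below is a minimal sketch of the save/restore pattern, assuming PyTorch's `torch.cuda.amp.GradScaler`; the helper names are hypothetical and not Lightning's API (the real logic lives in `pytorch_lightning/trainer/training_io.py`).

```python
import torch


def dump_checkpoint_with_scaler(model, optimizer, scaler):
    """Collect state dicts, including the native AMP GradScaler, into one dict."""
    return {
        'state_dict': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        # The line the buggy guard skipped: without it, the loss scale and
        # growth tracker are lost when training resumes from the checkpoint.
        'native_amp_scaling_state': scaler.state_dict(),
    }


def restore_scaler_state(checkpoint, scaler):
    """The 'key in checkpoint' guard belongs here, on the restore path."""
    if 'native_amp_scaling_state' in checkpoint:
        scaler.load_state_dict(checkpoint['native_amp_scaling_state'])


# Usage sketch (hypothetical file name):
# scaler = torch.cuda.amp.GradScaler()
# torch.save(dump_checkpoint_with_scaler(model, optimizer, scaler), 'ckpt.pt')
# ...later, after rebuilding model, optimizer, and scaler...
# restore_scaler_state(torch.load('ckpt.pt'), scaler)
```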