From d120f97896d201fb8fe5ea6083b3234321db536d Mon Sep 17 00:00:00 2001
From: Fabio Natanael Kepler
Date: Tue, 12 May 2020 02:38:37 +0100
Subject: [PATCH] Fix saving native AMP scaler state (#1777)

Saving was introduced in #1561.
---
 CHANGELOG.md                             | 3 +++
 pytorch_lightning/trainer/training_io.py | 4 ++--
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index a1ae50dccc58a..b2b0d8feb528e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -46,6 +46,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed lr key name in case of param groups in LearningRateLogger ([#1719](https://github.com/PyTorchLightning/pytorch-lightning/pull/1719))
 
+- Fixed saving native AMP scaler state (introduced in [#1561](https://github.com/PyTorchLightning/pytorch-lightning/pull/1561))
+
+
 ## [0.7.5] - 2020-04-27
 
 ### Changed
diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py
index 4f474b761e94f..437a89e42470f 100644
--- a/pytorch_lightning/trainer/training_io.py
+++ b/pytorch_lightning/trainer/training_io.py
@@ -338,8 +338,8 @@ def dump_checkpoint(self):
 
         checkpoint['state_dict'] = model.state_dict()
 
-        # restore native amp scaling
-        if self.use_amp and self.use_native_amp and 'native_amp_scaling_state' in checkpoint:
+        # save native amp scaling
+        if self.use_amp and self.use_native_amp:
             checkpoint['native_amp_scaling_state'] = self.scaler.state_dict()
 
         if hasattr(model, "hparams"):
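
For context: the bug was that `dump_checkpoint` reused the restore-time guard (`'native_amp_scaling_state' in checkpoint`) while building a fresh checkpoint dict, so the key was never present and the scaler state was never written. Below is a minimal, self-contained sketch of the intended save/restore pattern using `torch.cuda.amp.GradScaler`; it is illustrative only and does not reproduce Lightning's internals. The names `model`, `optimizer`, and `checkpoint.ckpt` are placeholders, not Lightning APIs.

```python
import torch

# Placeholder model/optimizer; the scaler is a no-op if CUDA is unavailable.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

# Saving: dump the scaler state alongside model/optimizer state,
# which is what the corrected branch in `dump_checkpoint` now does.
checkpoint = {
    'state_dict': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'native_amp_scaling_state': scaler.state_dict(),
}
torch.save(checkpoint, 'checkpoint.ckpt')

# Restoring: this is where the key-presence guard actually belongs,
# since checkpoints written before the fix may lack the scaler state.
checkpoint = torch.load('checkpoint.ckpt')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state'])
if 'native_amp_scaling_state' in checkpoint:
    scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
```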