Commit

Fix implementation
milesial committed Jan 10, 2023
1 parent 17a48b3 commit 4e71c28
Showing 4 changed files with 9 additions and 4 deletions.
src/lightning_fabric/plugins/precision/native_amp.py: 2 changes (1 addition, 1 deletion)
@@ -76,8 +76,8 @@ def optimizer_step(
         previous_scale = self.scaler.get_scale()
         # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
         step_output = self.scaler.step(optimizer, **kwargs)
-        model._skip_next_scheduler_step = self.scaler.get_scale() < previous_scale
         self.scaler.update()
+        optimizer._skip_next_scheduler_step = self.scaler.get_scale() < previous_scale
         return step_output

     def state_dict(self) -> Dict[str, Any]:
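For context on the hunk above: GradScaler.step() skips the wrapped optimizer.step() when it finds inf/NaN gradients, and GradScaler.update() then lowers the scale, so the scale only drops across update() when a step was skipped. Comparing get_scale() before step() with get_scale() after update() therefore detects the skip, which is why the flag is assigned after self.scaler.update() and attached to the optimizer rather than the model. A minimal standalone sketch of the pattern, assuming a CUDA device; the model, data, optimizer, and scheduler are placeholders, and only the _skip_next_scheduler_step attribute name comes from this commit:

import torch

# Sketch of the skip-detection trick (illustrative, not Lightning's actual code).
model = torch.nn.Linear(4, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
scaler = torch.cuda.amp.GradScaler()

for _ in range(3):
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(torch.randn(8, 4, device="cuda")).sum()
    scaler.scale(loss).backward()

    previous_scale = scaler.get_scale()
    scaler.step(optimizer)   # skips optimizer.step() if grads contain inf/NaN
    scaler.update()          # lowers the scale only when that step was skipped
    optimizer._skip_next_scheduler_step = scaler.get_scale() < previous_scale

    # Stepping the scheduler only when the optimizer actually stepped avoids the
    # "lr_scheduler.step() before optimizer.step()" warning.
    if not optimizer._skip_next_scheduler_step:
        scheduler.step()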
src/pytorch_lightning/loops/epoch/training_epoch_loop.py: 8 changes (5 additions, 3 deletions)
@@ -216,9 +216,8 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None:

         # update non-plateau LR schedulers
         # update epoch-interval ones only when we are at the end of training epoch
-        if not getattr(self.trainer.lightning_module, "_skip_next_scheduler_step", False):
-            self.update_lr_schedulers("step", update_plateau_schedulers=False)
-        elif self._num_ready_batches_reached():
+        self.update_lr_schedulers("step", update_plateau_schedulers=False)
+        if self._num_ready_batches_reached():
             self.update_lr_schedulers("epoch", update_plateau_schedulers=False)

         batch_end_outputs = self._prepare_outputs_training_batch_end(
@@ -451,6 +450,9 @@ def _update_learning_rates(
                 )
                 continue

+            if getattr(self.trainer.optimizers[config.opt_idx], "_skip_next_scheduler_step", False):
+                continue
+
             self.scheduler_progress.increment_ready()

             # update LR
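Taken together, the two hunks above move the skip decision from the LightningModule to the owning optimizer: advance() now always requests step-interval scheduler updates, and _update_learning_rates() skips an individual scheduler whenever its optimizer carries the flag set by the precision plugin. A simplified sketch of that consumption side, where SchedulerConfig and the function name are illustrative stand-ins for Lightning's internals, not the library's actual API:

from dataclasses import dataclass
from typing import Any, List

@dataclass
class SchedulerConfig:
    # Stand-in for Lightning's LRSchedulerConfig: a scheduler plus the index
    # of the optimizer it was configured with.
    scheduler: Any
    opt_idx: int

def update_step_schedulers(configs: List[SchedulerConfig], optimizers: List[Any]) -> None:
    """Advance step-interval schedulers, honoring the per-optimizer skip flag."""
    for config in configs:
        optimizer = optimizers[config.opt_idx]
        # Flag is set by the AMP precision plugin when GradScaler skipped optimizer.step().
        if getattr(optimizer, "_skip_next_scheduler_step", False):
            continue
        config.scheduler.step()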
src/pytorch_lightning/plugins/precision/native_amp.py: 2 changes (2 additions, 0 deletions)
@@ -85,8 +85,10 @@ def optimizer_step( # type: ignore[override]
         # in manual optimization, the closure does not return a value
         if not model.automatic_optimization or not skipped_backward:
             # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
+            previous_scale = self.scaler.get_scale()
             step_output = self.scaler.step(optimizer, **kwargs)
             self.scaler.update()
+            optimizer._skip_next_scheduler_step = self.scaler.get_scale() < previous_scale
             return step_output
         return closure_result

tests/tests_fabric/plugins/precision/test_native_amp.py: 1 change (1 addition, 0 deletions)
@@ -69,6 +69,7 @@ def test_native_amp_precision_backward():
 def test_native_amp_precision_optimizer_step_with_scaler():
     precision = MixedPrecision(precision="mixed", device="cuda")
     precision.scaler = Mock()
+    precision.scaler.get_scale = Mock(return_value=1.0)
     optimizer = Mock()

     precision.optimizer_step(optimizer, keyword="arg")
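The added get_scale mock keeps the new comparison in optimizer_step well defined: a bare Mock() returns another Mock from get_scale(), and ordering two plain Mock objects with "<" raises a TypeError, so the test pins get_scale to a concrete float. A small standalone illustration using only unittest.mock:

from unittest.mock import Mock

scaler = Mock()
# Without a concrete return value, "<" between two plain Mocks is undefined:
#   scaler.get_scale() < scaler.get_scale()  ->  TypeError
scaler.get_scale = Mock(return_value=1.0)

previous_scale = scaler.get_scale()
scaler.step(Mock())
scaler.update()
# 1.0 < 1.0 is False, so no scheduler step would be flagged for skipping.
assert (scaler.get_scale() < previous_scale) is False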
