Fix implementation
milesial committed Jan 10, 2023
1 parent 183a6a6 commit 0406a55
Showing 5 changed files with 11 additions and 5 deletions.
2 changes: 1 addition & 1 deletion  src/lightning_fabric/plugins/precision/native_amp.py

@@ -76,8 +76,8 @@ def optimizer_step(
         previous_scale = self.scaler.get_scale()
         # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
         step_output = self.scaler.step(optimizer, **kwargs)
-        model._skip_next_scheduler_step = self.scaler.get_scale() < previous_scale
         self.scaler.update()
+        optimizer._skip_next_scheduler_step = self.scaler.get_scale() < previous_scale
         return step_output

     def state_dict(self) -> Dict[str, Any]:
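For context on the ordering fixed above (my reading, not stated in the commit): torch.cuda.amp.GradScaler only lowers the scale inside update(), after a step() call that found non-finite gradients, so the comparison is only meaningful once update() has run; and the flag now lives on the optimizer because the Fabric plugin's optimizer_step has no model in scope, which appears to be what the deleted line tripped over. A minimal, self-contained sketch of that detection pattern, with illustrative names:

import torch

def step_was_skipped(scaler: torch.cuda.amp.GradScaler,
                     optimizer: torch.optim.Optimizer) -> bool:
    # Record the scale before stepping.
    previous_scale = scaler.get_scale()
    # step() silently skips optimizer.step() when inf/NaN gradients are found.
    scaler.step(optimizer)
    # update() backs off the scale only if that step was skipped.
    scaler.update()
    # A lower scale therefore means the optimizer step did not run,
    # so the matching LR-scheduler step should be skipped as well.
    return scaler.get_scale() < previous_scale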
8 changes: 5 additions & 3 deletions  src/pytorch_lightning/loops/epoch/training_epoch_loop.py

@@ -216,9 +216,8 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None:

         # update non-plateau LR schedulers
         # update epoch-interval ones only when we are at the end of training epoch
-        if not getattr(self.trainer.lightning_module, "_skip_next_scheduler_step", False):
-            self.update_lr_schedulers("step", update_plateau_schedulers=False)
-        elif self._num_ready_batches_reached():
+        self.update_lr_schedulers("step", update_plateau_schedulers=False)
+        if self._num_ready_batches_reached():
             self.update_lr_schedulers("epoch", update_plateau_schedulers=False)

         batch_end_outputs = self._prepare_outputs_training_batch_end(
@@ -451,6 +450,9 @@ def _update_learning_rates(
                         )
                         continue

+                if getattr(self.trainer.optimizers[config.opt_idx], "_skip_next_scheduler_step", False):
+                    continue
+
                 self.scheduler_progress.increment_ready()

                 # update LR
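A standalone sketch (plain torch objects and illustrative names, not Lightning internals) of the consumer side added in _update_learning_rates above: the flag set by the precision plugin on the optimizer suppresses the matching scheduler step, and since the plugin reassigns it on every optimizer_step, no explicit reset is needed.

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

def update_lr(optimizer: torch.optim.Optimizer,
              scheduler: torch.optim.lr_scheduler.StepLR) -> None:
    # Mirror of the added guard: skip the scheduler when the AMP scaler
    # skipped the corresponding optimizer step.
    if getattr(optimizer, "_skip_next_scheduler_step", False):
        return
    scheduler.step()

optimizer._skip_next_scheduler_step = True   # as the precision plugin would set it
update_lr(optimizer, scheduler)              # LR unchanged
optimizer._skip_next_scheduler_step = False  # next step ran normally
update_lr(optimizer, scheduler)              # LR is halved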
2 changes: 2 additions & 0 deletions  src/pytorch_lightning/plugins/precision/native_amp.py

@@ -85,8 +85,10 @@ def optimizer_step(  # type: ignore[override]
         # in manual optimization, the closure does not return a value
         if not model.automatic_optimization or not skipped_backward:
            # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
+            previous_scale = self.scaler.get_scale()
             step_output = self.scaler.step(optimizer, **kwargs)
             self.scaler.update()
+            optimizer._skip_next_scheduler_step = self.scaler.get_scale() < previous_scale
             return step_output
         return closure_result
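The plugin change above follows the loop pattern commonly recommended for GradScaler: step the LR scheduler only when the optimizer step actually ran. A compact sketch of that pattern in a plain training step, with a toy model and illustrative names (not the plugin's code):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(8, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 0.95 ** step)
scaler = torch.cuda.amp.GradScaler(enabled=device == "cuda")

x, y = torch.randn(16, 8, device=device), torch.randn(16, 1, device=device)
for _ in range(3):
    optimizer.zero_grad()
    with torch.autocast(device_type=device, enabled=device == "cuda"):
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()
    previous_scale = scaler.get_scale()
    scaler.step(optimizer)
    scaler.update()
    # Step the scheduler only when the optimizer step actually ran; otherwise
    # PyTorch warns about scheduler.step() being called before optimizer.step().
    if scaler.get_scale() >= previous_scale:
        scheduler.step()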
1 change: 1 addition & 0 deletions  tests/tests_fabric/plugins/precision/test_native_amp.py

@@ -69,6 +69,7 @@ def test_native_amp_precision_backward():
 def test_native_amp_precision_optimizer_step_with_scaler():
     precision = MixedPrecision(precision="mixed", device="cuda")
     precision.scaler = Mock()
+    precision.scaler.get_scale = Mock(return_value=1.0)
     optimizer = Mock()

     precision.optimizer_step(optimizer, keyword="arg")
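Why the extra mock line above is needed (my reading of the change): the plugin now evaluates get_scale() < previous_scale, and with a bare Mock() that comparison raises TypeError; pinning the return value to a float keeps the test on the happy path. A tiny illustration:

from unittest.mock import Mock

scaler = Mock()
try:
    scaler.get_scale() < scaler.get_scale()   # Mock objects are not orderable
except TypeError:
    pass

scaler.get_scale = Mock(return_value=1.0)
assert (scaler.get_scale() < scaler.get_scale()) is False   # 1.0 < 1.0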
3 changes: 2 additions & 1 deletion  tests/tests_pytorch/models/test_hooks.py

@@ -304,6 +304,7 @@ def _auto_train_batch(
         trainer, model, batches, device=torch.device("cpu"), current_epoch=0, current_batch=0, **kwargs
     ):
         using_deepspeed = kwargs.get("strategy") == "deepspeed"
+        using_native_amp = trainer.precision == 16 and trainer.amp_backend == 'native'
         out = []
         for i in range(current_batch, batches):
             out.extend(
@@ -347,7 +348,7 @@ def _auto_train_batch(
                         kwargs=dict(on_tpu=False, using_lbfgs=False),
                     ),
                     *(
-                        [dict(name="lr_scheduler_step", args=(ANY, 0, None))]
+                        [dict(name="lr_scheduler_step", args=(ANY, 0, None)) if not using_native_amp else ANY]
                         if i == (trainer.num_training_batches - 1)
                         else []
                     ),
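The test tweak above leans on unittest.mock.ANY, which compares equal to any value, so the expected-call list stays valid regardless of what lr_scheduler_step looks like for that batch under native AMP (my reading of the intent). A minimal illustration:

from unittest.mock import ANY

# ANY matches any object, so it can stand in for an expected hook record
# whose exact shape depends on runtime behaviour.
assert ANY == dict(name="lr_scheduler_step", args=(ANY, 0, None))
assert ANY == dict(name="optimizer_step")
assert [ANY, 2, 3] == [1, 2, 3]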
