
Mixed precision: scheduler and optimizer are called in the wrong order #5558

Open
manifoldhiker opened this issue Jan 18, 2021 · 36 comments
Labels: bug (Something isn't working), lr scheduler, precision: amp (Automatic Mixed Precision), priority: 2 (Low priority task)

Comments

@manifoldhiker

manifoldhiker commented Jan 18, 2021

🐛 Bug

When using mixed-precision training, the scheduler and optimizer are called in the wrong order, and the following warning is generated:

UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.

Please reproduce using the BoringModel

https://colab.research.google.com/drive/1G7pk6E9XUYq-pS41DXKhqM9Srx8sikiP?usp=sharing

There are four tests. Three of them don't raise the warning:

  1. test_amp_scheduler(precision=16, configure_optimizers=configure_optimizers_1)
  2. test_amp_scheduler(precision=32, configure_optimizers=configure_optimizers_1)
  3. test_amp_scheduler(precision=32, configure_optimizers=configure_optimizers_2)

This test case raises the warning:

  1. test_amp_scheduler(precision=16, configure_optimizers=configure_optimizers_2)

To Reproduce

  1. Create a model whose configure_optimizers returns the scheduler in the following dictionary style:
def configure_optimizers_2(model):
    optimizer = torch.optim.SGD(model.layer.parameters(), lr=0.1)
    scheduler = {
        'scheduler': torch.optim.lr_scheduler.StepLR(optimizer, step_size=1),
        'name': 'learning_rate',
        'interval': 'step',
        'frequency': 1,
    }
    return {"optimizer": optimizer, "lr_scheduler": scheduler}
  2. Enable mixed-precision training by setting precision=16 in the Trainer
  3. Start training (a minimal sketch of such a script follows)
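
A minimal, self-contained sketch along these lines (assuming the standard BoringModel pattern; the RandomDataset/BoringModel definitions here are illustrative and not the exact Colab code):

import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.randn(32)


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        # dict-style scheduler config from configure_optimizers_2 above
        return configure_optimizers_2(self)


trainer = Trainer(max_epochs=1, gpus=1, precision=16)
trainer.fit(BoringModel(), DataLoader(RandomDataset(), batch_size=8))
# With precision=16 and the dict-style scheduler config, the
# "lr_scheduler.step() before optimizer.step()" UserWarning is printed.
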

Note

When the scheduler is defined in the following way, the issue does not seem to occur:

def configure_optimizers_1(model):
    optimizer = torch.optim.SGD(model.layer.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
    
    return {"optimizer": optimizer, "lr_scheduler": scheduler}

Expected behavior

No warning

Environment

  • CUDA:
    • GPU:
      • Tesla P100-PCIE-16GB
    • available: True
    • version: 10.1
  • Packages:
    • numpy: 1.19.5
    • pyTorch_debug: True
    • pyTorch_version: 1.7.0+cu101
    • pytorch-lightning: 1.1.4
    • tqdm: 4.41.1
  • System:
    • OS: Linux
    • architecture:
      • 64bit
    • processor: x86_64
    • python: 3.6.9
    • version: #1 SMP Thu Jul 23 08:00:38 PDT 2020

cc @tchaton @rohitgr7 @carmocca @justusschock @awaelchli @akihironitta

@manifoldhiker manifoldhiker added bug Something isn't working help wanted Open to be worked on labels Jan 18, 2021
@github-actions
Contributor

Hi! Thanks for your contribution, great first issue!

@edenlightning edenlightning added priority: 1 Medium priority task good first issue Good for newcomers labels Jan 19, 2021
@javierlorenzod

Looking at the warning message, it seems that this is a problem related to precision. As explained in the documentation, when 16-bit precision is used, optimization is automatically managed by PyTorch Lightning, and from PyTorch >= 1.1.0 this ordering check produces the "Detected call of lr_scheduler.step() before optimizer.step()" warning. I do not know how to follow the trace on Colab; once I figure it out, I will search for the origin of this call. It seems that with 16-bit precision the order of the calls is different for this scheduler creation procedure.
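
For context, the ordering problem comes from how native AMP works: torch.cuda.amp.GradScaler.step() silently skips the underlying optimizer.step() whenever the unscaled gradients contain inf/NaN, so an unconditional scheduler.step() can run before any real optimizer step has happened. A minimal plain-PyTorch sketch of that interaction (not Lightning code, just the mechanism):

import torch

model = torch.nn.Linear(32, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
scaler = torch.cuda.amp.GradScaler()

for _ in range(4):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(torch.randn(8, 32, device="cuda")).sum()
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # silently skipped if the gradients contain inf/NaN
    scaler.update()
    scheduler.step()        # if the step above was skipped, PyTorch emits the warning
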

@stachu86

I'm getting the same warning when ddp_sharded is turned on. My optimizer is defined similarly to configure_optimizers_1

javierlorenzod added a commit to javierlorenzod/pytorch-lightning that referenced this issue Feb 16, 2021
javierlorenzod added a commit to javierlorenzod/pytorch-lightning that referenced this issue Feb 16, 2021
@stale stale bot added the won't fix This will not be worked on label Mar 19, 2021
@stale stale bot closed this as completed Mar 29, 2021
@sanxing-chen

Same issue.

@griff4692

I am still getting the same issue as well.

@akihironitta akihironitta reopened this May 12, 2021
@stale stale bot removed the won't fix This will not be worked on label May 12, 2021
@akihironitta
Contributor

@griff4692 @sanxing-chen Hi, thank you for your report. Which version are you using? Could you try with the latest version of pytorch-lightning? pip install pytorch-lightning -U

@javierlorenzod

@akihironitta I have run the Colab from the beginning of the issue again, and the warning is still there. This is the environment printed by the collection script:

* CUDA:
	- GPU:
		- Tesla P100-PCIE-16GB
	- available:         True
	- version:           10.1
* Packages:
	- numpy:             1.19.5
	- pyTorch_debug:     False
	- pyTorch_version:   1.8.1+cu101
	- pytorch-lightning: 1.3.1
	- tqdm:              4.41.1
* System:
	- OS:                Linux
	- architecture:
		- 64bit
		- 
	- processor:         x86_64
	- python:            3.7.10
	- version:           #1 SMP Tue Apr 20 19:55:43 PDT 2021

@akihironitta
Contributor

@javierlorenzod Thanks a lot for your report! Let me look into it.

@benihime91

I am using pytorch-lightning==1.3.3, and the problem seems to exist there as well.

@cpk26

cpk26 commented Jun 7, 2021

As another datapoint, I'm finding this issue with pytorch-lightning==1.3.4

@stale stale bot added the won't fix This will not be worked on label Jul 7, 2021
@Lightning-AI Lightning-AI deleted a comment from stale bot Jul 8, 2021
@stale stale bot removed the won't fix This will not be worked on label Jul 8, 2021
@xfffrank

xfffrank commented Jul 26, 2021

Same issue here with pytorch-lightning==1.3.1.

@stale stale bot added the won't fix This will not be worked on label Aug 25, 2021
@Lightning-AI Lightning-AI deleted a comment from stale bot Aug 30, 2021
@stale stale bot removed the won't fix This will not be worked on label Aug 30, 2021
@Lightning-AI Lightning-AI deleted a comment from stale bot Aug 30, 2021
@jstremme

Same issue with pytorch-lightning==1.4.1.

@carmocca carmocca modified the milestones: future, 1.5.x Feb 3, 2022
@carmocca carmocca added priority: 2 Low priority task and removed priority: 1 Medium priority task labels Mar 1, 2022
@Borda Borda modified the milestones: 1.5.x, 1.6.x Mar 21, 2022
@alimoezzi

alimoezzi commented Jun 2, 2022

I'm using pytorch-lightning==1.6.4 and still see the same issue.

@akihironitta
Contributor

A quick workaround is to override LightningModule.lr_scheduler_step() (only with PL 1.6.0 or later) so that it skips lr_scheduler.step() whenever the scaler skips optimizer.step(). For multiple optimizers it needs some changes, but for a single optimizer, the following should work:

class YourLightningModule(LightningModule):
    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, **kwargs):
        self.should_skip_lr_scheduler_step = False
        scaler = getattr(self.trainer.strategy.precision_plugin, "scaler", None)
        if scaler:
            scale_before_step = scaler.get_scale()
        optimizer.step(closure=optimizer_closure)
        if scaler:
            scale_after_step = scaler.get_scale()
            # GradScaler lowers the scale only when it skipped optimizer.step()
            # because of inf/NaN gradients, so a decrease means the step was skipped.
            self.should_skip_lr_scheduler_step = scale_before_step > scale_after_step

    def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
        if self.should_skip_lr_scheduler_step:
            return
        scheduler.step()

See here for a complete script using BoringModel: https://github.com/akihironitta/gist/blob/repro/5558-amp-scheduler-workaround/pl_boring_model/main.py

@akihironitta akihironitta added the precision: amp Automatic Mixed Precision label Jul 21, 2022
@collinmccarthy

collinmccarthy commented Jul 26, 2022

I'm not using PTL right now, but I'm interested in the "right" solution here. As other people have said, the issue has nothing to do with PTL itself.

@akihironitta A couple of comments / questions.

  • I believe scaler.step(optimizer) will simply return None if optimizer.step() was never called due to NaN/inf (see here)
  • I'm torn as to whether it's better to simply call scheduler.step() every time and just try to catch/squash the warning. Maybe that doesn't work well for PTL, but if I'm using an LR schedule I expect it to be followed regardless of whether 16-bit precision errors are inhibiting grad updates for a few iterations. I think I'd rather just stick to the schedule and update it every time. It's not like I'm re-doing the batch if the scaling produced NaNs, I'm just moving on to the next batch. Again, I'm torn as to the "right" approach, but in the end it probably doesn't matter in terms of the final trained weights.

Cheers,
-Collin

Edit: I couldn't successfully suppress the warning, so I ended up comparing to None and skipping.

Edit2: Testing for None as a return value doesn't work for all optimizers, e.g. AdamW without a closure will return None even when stepped. So testing the scale before and after seems like the best way.
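
To make Edit2 concrete: optimizer.step() only returns a value when it is given a closure, so scaler.step(optimizer) returning None does not necessarily mean the step was skipped. A small sketch of that behaviour (plain PyTorch, illustrative only):

import torch

model = torch.nn.Linear(32, 2).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

with torch.cuda.amp.autocast():
    loss = model(torch.randn(8, 32, device="cuda")).sum()
scaler.scale(loss).backward()

ret = scaler.step(optimizer)  # no closure passed
scaler.update()
# ret is None here even though the optimizer stepped, because AdamW.step()
# returns None when called without a closure; comparing scaler.get_scale()
# before and after scaler.update() is the more reliable skip check.
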

@carmocca carmocca modified the milestones: pl:1.6.x, pl:future Jul 28, 2022
@akihironitta
Contributor

Hi @collinmccarthy, thank you for your comment.

The issue has nothing to do with PTL like other people have said.

Yes, as I commented a while ago #5558 (comment), this issue stems from how amp is implemented.

if I’m using a LR schedule I expect it to be followed regardless of whether or not 16-bit precision errors are inhibiting grad updates a few iterations. I think I’d rather just stick to the schedule and update it every time. It’s not like I’m re-doing the batch if the scaling produced NaNs, I’m just moving onto the next batch

That's totally fine if that works for you. However, some people might still prefer the hack above to avoid excessive lr_scheduler.step() calls, which is why I left the code snippet above. If you're fine with calling lr_scheduler.step() excessively, you can just ignore the warning. If you find it too noisy, you can suppress it with:

import warnings
warnings.filterwarnings("ignore", "Detected call of", UserWarning)

https://docs.python.org/3/library/warnings.html#warnings.filterwarnings

@YooSungHyun

@akihironitta Hi! I'm using the YourLightningModule code above, but at some epoch I get this error:

ValueError: Tried to step 42552 times. The specified number of total steps is 42550
    self._outputs = self.epoch_loop.run(self._data_fetcher)
  File "/data/asr_proj/stt/RNNTransducer/.venv/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/data/asr_proj/stt/RNNTransducer/.venv/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 220, in advance
    self.update_lr_schedulers("step", update_plateau_schedulers=False)
  File "/data/asr_proj/stt/RNNTransducer/.venv/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 397, in update_lr_schedulers
    self._update_learning_rates(
  File "/data/asr_proj/stt/RNNTransducer/.venv/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 458, in _update_learning_rates
    self.trainer._call_lightning_module_hook(
  File "/data/asr_proj/stt/RNNTransducer/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1305, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/data/asr_proj/stt/RNNTransducer/model.py", line 200, in lr_scheduler_step
    scheduler.step()
  File "/data/asr_proj/stt/RNNTransducer/.venv/lib/python3.9/site-packages/torch/optim/lr_scheduler.py", line 161, in step
    values = self.get_lr()
  File "/data/asr_proj/stt/RNNTransducer/.venv/lib/python3.9/site-packages/torch/optim/lr_scheduler.py", line 1686, in get_lr
    raise ValueError("Tried to step {} times. The specified number of total steps is {}"
ValueError: Tried to step 42552 times. The specified number of total steps is 42550

My code looks like this:

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, **kwargs):
        self.should_skip_lr_scheduler_step = False
        scaler = getattr(self.trainer.strategy.precision_plugin, "scaler", None)
        if scaler:
            scale_before_step = scaler.get_scale()
        optimizer.step(closure=optimizer_closure)
        if scaler:
            scale_after_step = scaler.get_scale()
            self.should_skip_lr_scheduler_step = scale_before_step > scale_after_step

    def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
        if self.should_skip_lr_scheduler_step:
            return
        scheduler.step()

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            [{"params": [p for p in self.parameters()], "name": "OneCycleLR"}],
            lr=self.args.learning_rate,
            weight_decay=self.args.weight_decay,
        )
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.args.max_lr,
            steps_per_epoch=self.steps_per_epoch,
            epochs=self.trainer.max_epochs,
            pct_start=0.05,
        )
        lr_scheduler = {"interval": "step", "scheduler": scheduler, "name": "AdamW"}
        return [optimizer], [lr_scheduler]
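
The ValueError above is OneCycleLR reporting that it was stepped more times than its computed budget (steps_per_epoch * epochs), which usually means that budget is slightly smaller than the number of scheduler steps Lightning actually performs. One way to avoid computing steps_per_epoch by hand (a sketch, assuming PL >= 1.6 where Trainer.estimated_stepping_batches is available, and reusing self.args from the snippet above):

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.args.learning_rate,
            weight_decay=self.args.weight_decay,
        )
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.args.max_lr,
            # let the trainer compute the total number of optimizer steps
            total_steps=int(self.trainer.estimated_stepping_batches),
            pct_start=0.05,
        )
        lr_scheduler = {"interval": "step", "scheduler": scheduler, "name": "AdamW"}
        return [optimizer], [lr_scheduler]
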

@YooSungHyun

YooSungHyun commented Nov 21, 2022

@collinmccarthy
Totally agree. I agree with you and will test fp32 OneCycleLR and fp16 OneCycleLR with the warning ignored.
When I managed the learning rate myself, I ran into many more errors and side effects 😂. Optimization is so hard for me 😣

@YooSungHyun

YooSungHyun commented Nov 23, 2022

In my case, the warning is not important. I logged loss and lr for:

  1. fp16 with the warning ignored: pink
  2. fp32: purple
  3. fp16 with the lr_scheduler_step override: brown

[plot of the logged loss and lr curves]

The three cases differ from each other, but only very slightly, so I don't mind the warning being printed now.

cuda 11.4
python 3.9
pytorch-lightning 1.8.1
torch 1.13.0

@milesial

milesial commented Jan 3, 2023

I propose a fix in #16229. The issue is not on the PyTorch side, it's on the PTL side.

When using an LR scheduler stepped per step together with AMP, the PyTorch user (here, PTL) should check that the optimizer step wasn't skipped by the grad scaler before stepping the scheduler.

In this PR I use the same check that PyTorch uses to generate that warning: optimizer._step_count.
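
For illustration, the kind of check this refers to (a sketch of the general idea, not the actual code from #16229): torch.optim.lr_scheduler wraps optimizer.step() so that each real call increments optimizer._step_count, and GradScaler.step() skips that call entirely on inf/NaN gradients, so the counter can be compared before and after the step:

import torch

model = torch.nn.Linear(32, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
scaler = torch.cuda.amp.GradScaler()

for _ in range(4):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(torch.randn(8, 32, device="cuda")).sum()
    scaler.scale(loss).backward()
    step_count_before = optimizer._step_count
    scaler.step(optimizer)
    scaler.update()
    if optimizer._step_count > step_count_before:
        scheduler.step()  # only advance the schedule if the optimizer really stepped
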

@awaelchli
Member

I think we should implement pytorch/pytorch#67590 (PyTorch). Any additions in Lightning would always be workarounds.

@ayansengupta17

Following and waiting.

@morestart

any update?

@yipliu

yipliu commented Oct 14, 2023

The issue is still present with pytorch==2.1.0 and pytorch-lightning==2.1.0.

@oguz-hanoglu

In my case, the warning is raised during the first four steps, while an epoch consists of 500+ steps. Since the skipping happens on the very first step, I also receive the "scheduler called before optimizer is called" warning. I'd like to address these warnings not only because they are annoying and can lead others on the project to assume there is a significant problem, but also because there is no guarantee that the skipped optimizer steps will always be this limited.

I have noticed that my optimizer (AdamW) has _step_count in it. After debugging it, I observed that the count is not increased during skipped steps. Therefore, another possible workaround would be:

class YourLightningModule(LightningModule):
    def __init__(self):
        super().__init__()
        ...
        self.scheduler_step_counter = 0

    ...

    def lr_scheduler_step(self, scheduler, metric):
        if self.scheduler_step_counter < scheduler.optimizer._step_count:
            super().lr_scheduler_step(scheduler, metric)
            self.scheduler_step_counter += 1
            assert (
                self.scheduler_step_counter == scheduler.optimizer._step_count
            ), "scheduler_step_counter should be equal to optimizer._step_count"

@huuquan1994

pytorch==2.2.2
lightning==2.2.1

I'm having the same warning UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()` when I set precision='16'.

But the warning disappears when I set precision='16-mixed'.
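
For reference, a minimal sketch of the two settings being compared (Lightning 2.x Trainer; the comments reflect the behaviour reported above, not a guaranteed difference):

from lightning.pytorch import Trainer

trainer_warns = Trainer(precision="16")        # legacy value; reported to show the warning
trainer_clean = Trainer(precision="16-mixed")  # explicit mixed-precision setting; reported not to
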
