Wrong LR scheduler behaviour when using AMP #16228

Closed
milesial opened this issue Jan 3, 2023 · 1 comment · May be fixed by #16229
Labels
duplicate (This issue or pull request already exists)

Comments


milesial commented Jan 3, 2023

Bug description

When training with native AMP and an LR scheduler, we get a warning indicating that a scheduler step was taken even though the optimizer step was skipped (skipped optimizer steps are expected at the beginning of training with native AMP, while the grad scaler calibrates its loss scale); see the warning under "Error messages and logs" below.

This can be fixed by wrapping these lines https://github.com/Lightning-AI/lightning/blob/574a9516012b4ab778254055c537f5d57e8e694f/src/pytorch_lightning/core/module.py#L1589-L1592 in an if hasattr(optimizer, '_step_count') and optimizer._step_count > 0: check, so the scheduler is only stepped once the optimizer has actually taken a step (a sketch follows below).

Fix proposed in #16229
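
A minimal sketch of the proposed guard, assuming the linked lines are the default LightningModule.lr_scheduler_step body of that release (the exact hook signature varies between Lightning versions, so the names here are illustrative):

def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
    # torch's LR schedulers wrap optimizer.step() with a counter, so
    # `optimizer._step_count` only increases when a real optimizer step is
    # taken; steps skipped by the AMP grad scaler leave it unchanged.
    optimizer = scheduler.optimizer
    if hasattr(optimizer, "_step_count") and optimizer._step_count > 0:
        if metric is None:
            scheduler.step()
        else:
            scheduler.step(metric)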

How to reproduce the bug

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        # The huge multiplier makes the fp16 loss/gradients overflow, so the
        # AMP grad scaler skips optimizer steps and the warning is triggered.
        loss = self(batch).sum() * 100000000
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        opt = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        sched = torch.optim.lr_scheduler.StepLR(opt, 10)
        return {"optimizer": opt, 'lr_scheduler': {"scheduler": sched, "interval": "step"}}

def run():
    train_data = DataLoader(RandomDataset(32, 32), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        accelerator='gpu',
        devices=1,
        default_root_dir=os.getcwd(),
        num_sanity_val_steps=0,
        max_epochs=1,
        precision=16,
        enable_model_summary=False,
    )
    trainer.fit(model, train_dataloaders=train_data)


if __name__ == "__main__":
    run()

Error messages and logs

/usr/local/lib/python3.8/dist-packages/torch/optim/lr_scheduler.py:138: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate

Environment

No response

More info

No response

@milesial milesial added the needs triage label on Jan 3, 2023

carmocca (Contributor) commented Jan 3, 2023

Duplicate of #5558, which suggests workarounds like #5558 (comment). Feel free to comment there.
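
One possible user-side workaround, sketched here under the assumption that the PL 1.8/1.9-era hook signature lr_scheduler_step(self, scheduler, optimizer_idx, metric) applies (this is illustrative, not necessarily the workaround suggested in #5558 (comment)): override the hook in the LightningModule and apply the same _step_count guard as the proposed fix, without patching Lightning itself.

class BoringModel(LightningModule):
    # ... rest of the module as in the reproduction above ...

    def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
        # Same guard as the proposed fix; `metric` handling is omitted because
        # StepLR in the repro does not use one.
        optimizer = scheduler.optimizer
        if hasattr(optimizer, "_step_count") and optimizer._step_count > 0:
            scheduler.step()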

@carmocca carmocca closed this as not planned on Jan 3, 2023
@carmocca carmocca added the duplicate label and removed the needs triage label on Jan 3, 2023