
LearningRateMonitor KeyError with multiple parameter groups and no LR scheduler #10024

Closed
eladsegal opened this issue Oct 19, 2021 · 5 comments · Fixed by #10044

@eladsegal
Contributor

🐛 Bug

This bug is due to the change in #9786:
When an optimizer has multiple parameter groups and no learning rate scheduler, _find_names_from_optimizers returns names that already carry the /pg{n} suffix, but _get_lr_momentum_stat then appends the suffix a second time, resulting in a KeyError.
It only happens with multiple parameter groups, since _add_suffix does not add a suffix for a single group.
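
To make the double-suffixing concrete, here is a minimal sketch of how the mismatched keys come about. This is not the actual lr_monitor.py code; add_suffix below is a hypothetical stand-in for the callback's suffixing helper:

def add_suffix(name, n_groups, group_idx):
    # Hypothetical helper mirroring the behavior described above: with
    # multiple parameter groups, append "/pg{n}"; with a single group, don't.
    return f"{name}/pg{group_idx + 1}" if n_groups > 1 else name

n_groups = 2

# Setup time: stat names are registered with the suffix applied once.
lrs = {add_suffix("lr-Adam", n_groups, i): [] for i in range(n_groups)}
# lrs == {'lr-Adam/pg1': [], 'lr-Adam/pg2': []}

# Logging time: the already-suffixed name is suffixed again, so the lookup
# key no longer matches any registered name.
key = add_suffix("lr-Adam/pg1", n_groups, 0)  # 'lr-Adam/pg1/pg1'
lrs[key].append(2e-4)  # KeyError: 'lr-Adam/pg1/pg1'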

eladsegal added bug and help wanted labels Oct 19, 2021
rohitgr7 self-assigned this Oct 19, 2021
@kandluis
Contributor

We have a few workflows that are facing this issue as well. What's the expected ETA on a fix?

@tangbinh
Contributor

We also have an internal workflow failing for the same reason and expect more to follow. Here's a script that reproduces the problem described by @eladsegal:

import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(32, 16), nn.Linear(16, 2))

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        # Multiple parameter groups with no LR scheduler -- the combination
        # that triggers the KeyError in LearningRateMonitor.
        return torch.optim.Adam([
            {'params': self.net[0].parameters(), 'lr': 2e-4},
            {'params': self.net[1].parameters(), 'lr': 1e-3}
        ])


train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

model = BoringModel()
trainer = Trainer(
    default_root_dir=os.getcwd(),
    limit_train_batches=1,
    limit_val_batches=1,
    num_sanity_val_steps=0,
    max_epochs=1,
    callbacks=[LearningRateMonitor()]
)
trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
trainer.test(model, dataloaders=test_data)

Here's the stack trace we got after running the script on master:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-17-788ac4a26c1a> in <module>()
     62     callbacks=[LearningRateMonitor()]
     63 )
---> 64 trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
     65 trainer.test(model, dataloaders=test_data)

17 frames
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/callbacks/lr_monitor.py in _extract_lr(self, param_group, name)
    210     def _extract_lr(self, param_group: Dict[str, Any], name: str) -> Dict[str, Any]:
    211         lr = param_group.get("lr")
--> 212         self.lrs[name].append(lr)
    213         return {name: lr}
    214 

KeyError: 'lr-Adam/pg1/pg1'
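
Note the doubled /pg1 in the failing key: the name registered in self.lrs is lr-Adam/pg1, but the lookup suffixes it again, matching the double-suffixing described above.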

@rohitgr7
Contributor

thank you guys for raising this..
ETA: today (most probably) :)

@rohitgr7
Contributor

hey!
created a fix here: #10044
can anyone confirm if it's working for them now?
thanks :)

@eladsegal
Contributor Author

Hey, I can confirm the fix works.
Thanks!
