Log LR using LearningRateMonitor even when LR Scheduler is not defined. #9786

Merged
Changes from 7 commits (of 19)
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Log `learning rate` using `LearningRateMonitor` callback even when no scheduler is used ([#7468](https://github.com/PyTorchLightning/pytorch-lightning/issues/7468))


- Register `ShardedTensor` state dict hooks in `LightningModule.__init__` if the pytorch version supports `ShardedTensor` ([#8944](https://github.com/PyTorchLightning/pytorch-lightning/pull/8944))

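A minimal usage sketch of what the changelog entry above describes: with this change, `LearningRateMonitor` logs `lr-SGD` even though `configure_optimizers` returns no scheduler (previously this setup only raised a `RuntimeWarning` and logged nothing). The module, dataset, and hyperparameters below are illustrative stand-ins, not part of the PR.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor


class NoSchedulerModel(LightningModule):
    """Toy module that returns an optimizer but no LR scheduler."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        # Deliberately no scheduler: only the optimizer is returned.
        return torch.optim.SGD(self.parameters(), lr=1e-2)


lr_monitor = LearningRateMonitor(logging_interval="step")
train_data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
trainer = Trainer(max_epochs=1, callbacks=[lr_monitor])  # default Trainer logger is enough
trainer.fit(NoSchedulerModel(), train_data)
# After fitting, lr_monitor.lrs holds an "lr-SGD" entry taken from the optimizer itself.
```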
78 changes: 62 additions & 16 deletions pytorch_lightning/callbacks/lr_monitor.py
@@ -106,18 +106,13 @@ def on_train_start(self, trainer, *args, **kwargs):
"Cannot use `LearningRateMonitor` callback with `Trainer` that has no logger."
)

if not trainer.lr_schedulers:
rank_zero_warn(
"You are using `LearningRateMonitor` callback with models that"
" have no learning rate schedulers. Please see documentation"
" for `configure_optimizers` method.",
RuntimeWarning,
)

if self.log_momentum:

def _check_no_key(key):
return any(key not in sch["scheduler"].optimizer.defaults for sch in trainer.lr_schedulers)
if trainer.lr_schedulers:
return any(key not in sch["scheduler"].optimizer.defaults for sch in trainer.lr_schedulers)

return any(key not in optimizer.defaults for optimizer in trainer.optimizers)

if _check_no_key("momentum") and _check_no_key("betas"):
rank_zero_warn(
@@ -127,7 +122,11 @@ def _check_no_key(key):
)

# Find names for schedulers
names = self._find_names(trainer.lr_schedulers)
names = (
self._find_names(trainer.lr_schedulers)
if trainer.lr_schedulers
else self._find_names_from_optimizer(trainer)
)

# Initialize for storing values
self.lrs = {name: [] for name in names}
@@ -155,12 +154,30 @@ def on_train_epoch_start(self, trainer, *args, **kwargs):
def _extract_stats(self, trainer, interval: str) -> Dict[str, float]:
latest_stat = {}

names = self._find_names(trainer.lr_schedulers, add_lr_sch_names=False)
self._remap_keys(names)

for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers):
if scheduler["interval"] == interval or interval == "any":
opt = scheduler["scheduler"].optimizer
if trainer.lr_schedulers:
names = self._find_names(trainer.lr_schedulers, add_lr_sch_names=False)
self._remap_keys(names)

for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers):
if interval in [scheduler["interval"], "any"]:
opt = scheduler["scheduler"].optimizer
param_groups = opt.param_groups
use_betas = "betas" in opt.defaults

for i, pg in enumerate(param_groups):
name_and_suffix = self._add_suffix(name, param_groups, i)
lr = self._extract_lr(pg, name_and_suffix)
latest_stat.update(lr)
momentum = self._extract_momentum(
param_group=pg, name=name_and_suffix.replace(name, f"{name}-momentum"), use_betas=use_betas
)
latest_stat.update(momentum)
else:
names = self._find_names_from_optimizer(trainer, add_lr_sch_names=False)
self._remap_keys(names)

for idx, name in enumerate(names):
opt = trainer.optimizers[idx]
param_groups = opt.param_groups
use_betas = "betas" in opt.defaults

@@ -259,6 +276,35 @@ def _find_names(self, lr_schedulers: List, add_lr_sch_names: bool = True) -> List[str]:

return names

def _find_names_from_optimizer(self, trainer, add_lr_sch_names: bool = True) -> List[str]:
names = []
seen_optimizers = []
seen_optimizer_types = defaultdict(int)
for optimizer in trainer.optimizers:
name = "lr-" + optimizer.__class__.__name__

seen_optimizers.append(optimizer)
optimizer_cls = type(optimizer)
seen_optimizer_types[optimizer_cls] += 1

# Multiple param groups for the same scheduler
param_groups = optimizer.param_groups
duplicates = self._duplicate_param_group_names(param_groups)
if duplicates:
raise MisconfigurationException(
"A single `Optimizer` cannot have multiple parameter groups with identical "
f"`name` values. {name} has duplicated parameter group names {duplicates}"
)

name = self._add_prefix(name, optimizer_cls, seen_optimizer_types)

names.extend(self._add_suffix(name, param_groups, i) for i in range(len(param_groups)))

if add_lr_sch_names:
self.lr_sch_names.append(name)

return names

@staticmethod
def _should_log(trainer) -> bool:
return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop
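As a side note on how the new `_find_names_from_optimizer` helper names things when no scheduler exists: below is a rough, standalone sketch of the de-duplication scheme, assuming it mirrors the existing `_add_prefix` convention of appending `-1`, `-2`, … for repeated optimizer classes. The `sketch_optimizer_names` helper is hypothetical, written only to illustrate the expected output.

```python
from collections import defaultdict

import torch


def sketch_optimizer_names(optimizers):
    """Illustrative only: mimic the 'lr-<ClassName>' naming with numeric suffixes for repeats."""
    names = []
    seen_optimizer_types = defaultdict(int)
    for optimizer in optimizers:
        base = "lr-" + optimizer.__class__.__name__
        seen_optimizer_types[type(optimizer)] += 1
        count = seen_optimizer_types[type(optimizer)]
        # First occurrence keeps the plain name; later ones get "-1", "-2", ...
        names.append(base if count == 1 else f"{base}-{count - 1}")
    return names


params = [torch.nn.Parameter(torch.zeros(1))]
print(sketch_optimizer_names([torch.optim.Adam(params), torch.optim.Adam(params)]))
# -> ['lr-Adam', 'lr-Adam-1'], matching the expectation in the multi-optimizer test below.
```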
130 changes: 128 additions & 2 deletions tests/callbacks/test_lr_monitor.py
@@ -122,7 +122,8 @@ def configure_optimizers(self):
), "Names of momentum values not set correctly"


def test_lr_monitor_no_lr_scheduler(tmpdir):
def test_lr_monitor_no_lr_scheduler_single_lr(tmpdir):
"""Test that learning rates are extracted and logged for no lr scheduler."""
tutils.reset_seed()

class CustomBoringModel(BoringModel):
@@ -137,10 +138,87 @@ def configure_optimizers(self):
default_root_dir=tmpdir, max_epochs=2, limit_val_batches=0.1, limit_train_batches=0.5, callbacks=[lr_monitor]
)

with pytest.warns(RuntimeWarning, match="have no learning rate schedulers"):
# with pytest.warns(RuntimeWarning, match="have no learning rate schedulers"):
trainer.fit(model)
assert trainer.state.finished, f"Training failed with {trainer.state}"

assert lr_monitor.lrs, "No learning rates logged"
assert len(lr_monitor.lrs) == len(
trainer.optimizers
), "Number of learning rates logged does not match number of optimizers"
assert lr_monitor.lr_sch_names == ["lr-SGD"], "Names of learning rates not set correctly"


@pytest.mark.parametrize("opt", ["SGD", "Adam"])
def test_lr_monitor_no_lr_scheduler_single_lr_with_momentum(tmpdir, opt: str):
"""Test that learning rates and momentum are extracted and logged for no lr scheduler."""

class LogMomentumModel(BoringModel):
def __init__(self, opt):
super().__init__()
self.opt = opt

def configure_optimizers(self):
if self.opt == "SGD":
opt_kwargs = {"momentum": 0.9}
elif self.opt == "Adam":
opt_kwargs = {"betas": (0.9, 0.999)}

optimizer = getattr(optim, self.opt)(self.parameters(), lr=1e-2, **opt_kwargs)
return [optimizer]

model = LogMomentumModel(opt=opt)
lr_monitor = LearningRateMonitor(log_momentum=True)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
limit_val_batches=2,
limit_train_batches=5,
log_every_n_steps=1,
callbacks=[lr_monitor],
)
trainer.fit(model)
assert trainer.state.finished, f"Training failed with {trainer.state}"

assert all(v is not None for v in lr_monitor.last_momentum_values.values()), "Expected momentum to be logged"
assert len(lr_monitor.last_momentum_values) == len(
trainer.optimizers
), "Number of momentum values logged does not match number of optimizers"
assert all(
k == f"lr-{opt}-momentum" for k in lr_monitor.last_momentum_values.keys()
), "Names of momentum values not set correctly"


def test_log_momentum_no_momentum_optimizer_no_lr_scheduler(tmpdir):
"""Test that if optimizer doesn't have momentum then a warning is raised with log_momentum=True."""

class LogMomentumModel(BoringModel):
def configure_optimizers(self):
optimizer = optim.ASGD(self.parameters(), lr=1e-2)
return [optimizer]

model = LogMomentumModel()
lr_monitor = LearningRateMonitor(log_momentum=True)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_val_batches=2,
limit_train_batches=5,
log_every_n_steps=1,
callbacks=[lr_monitor],
)
with pytest.warns(RuntimeWarning, match="optimizers do not have momentum."):
trainer.fit(model)
assert trainer.state.finished, f"Training failed with {trainer.state}"

assert all(v == 0 for v in lr_monitor.last_momentum_values.values()), "Expected momentum to be logged"
assert len(lr_monitor.last_momentum_values) == len(
trainer.optimizers
), "Number of momentum values logged does not match number of optimizers"
assert all(
k == "lr-ASGD-momentum" for k in lr_monitor.last_momentum_values.keys()
), "Names of momentum values not set correctly"


def test_lr_monitor_no_logger(tmpdir):
tutils.reset_seed()
@@ -204,6 +282,54 @@ def configure_optimizers(self):
), "Length of logged learning rates do not match the expected number"


@pytest.mark.parametrize("logging_interval", ["step", "epoch"])
def test_lr_monitor_no_lr_scheduler_multi_lrs(tmpdir, logging_interval: str):
"""Test that learning rates are extracted and logged for multi optimizers but no lr scheduler."""
tutils.reset_seed()

class CustomBoringModel(BoringModel):
def training_step(self, batch, batch_idx, optimizer_idx):
return super().training_step(batch, batch_idx)

def configure_optimizers(self):
optimizer1 = optim.Adam(self.parameters(), lr=1e-2)
optimizer2 = optim.Adam(self.parameters(), lr=1e-2)

return [optimizer1, optimizer2]

model = CustomBoringModel()
model.training_epoch_end = None

lr_monitor = LearningRateMonitor(logging_interval=logging_interval)
log_every_n_steps = 2

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
log_every_n_steps=log_every_n_steps,
limit_train_batches=7,
limit_val_batches=0.1,
callbacks=[lr_monitor],
)
trainer.fit(model)
assert trainer.state.finished, f"Training failed with {trainer.state}"

assert lr_monitor.lrs, "No learning rates logged"
assert len(lr_monitor.lrs) == len(
trainer.optimizers
), "Number of learning rates logged does not match number of optimizers"
assert lr_monitor.lr_sch_names == ["lr-Adam", "lr-Adam-1"], "Names of learning rates not set correctly"

if logging_interval == "step":
expected_number_logged = trainer.global_step // log_every_n_steps
if logging_interval == "epoch":
expected_number_logged = trainer.max_epochs

assert all(
len(lr) == expected_number_logged for lr in lr_monitor.lrs.values()
), "Length of logged learning rates do not match the expected number"


def test_lr_monitor_param_groups(tmpdir):
"""Test that learning rates are extracted and logged for single lr scheduler."""
tutils.reset_seed()