Commit: cleanup
rohitgr7 committed Oct 14, 2021
1 parent 801d928 commit 5736d3a
Showing 2 changed files with 18 additions and 60 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added


-- Log `learning_rate` using `LearningRateMonitor` callback even when no scheduler is used ([#9786](https://github.com/PyTorchLightning/pytorch-lightning/issues/9786))
+- Add support for monitoring the learning rate monitor without schedulers in `LearningRateMonitor` ([#9786](https://github.com/PyTorchLightning/pytorch-lightning/issues/9786))


- Register `ShardedTensor` state dict hooks in `LightningModule.__init__` if the pytorch version supports `ShardedTensor` ([#8944](https://github.com/PyTorchLightning/pytorch-lightning/pull/8944))
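For context, the `LearningRateMonitor` entry above describes logging the learning rate even when `configure_optimizers` returns no scheduler. A minimal sketch of that behavior, assuming the ~1.4-era `pytorch_lightning` public API; the `NoSchedulerModel`, the synthetic data, and the printed values are illustrative, not part of this commit:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor


class NoSchedulerModel(pl.LightningModule):
    """Illustrative module whose configure_optimizers returns no scheduler."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        # A bare optimizer: no LR scheduler is configured.
        return torch.optim.SGD(self.parameters(), lr=0.1)


lr_monitor = LearningRateMonitor(logging_interval="step")
trainer = pl.Trainer(max_epochs=1, log_every_n_steps=1, callbacks=[lr_monitor])
data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=16)
trainer.fit(NoSchedulerModel(), data)

# The monitor now tracks the bare optimizer, keyed by its class name.
print(lr_monitor.lrs)  # e.g. {'lr-SGD': [0.1, 0.1, 0.1, 0.1]}
```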
76 changes: 17 additions & 59 deletions tests/callbacks/test_lr_monitor.py
@@ -37,16 +37,11 @@ def test_lr_monitor_single_lr(tmpdir):
default_root_dir=tmpdir, max_epochs=2, limit_val_batches=0.1, limit_train_batches=0.5, callbacks=[lr_monitor]
)
trainer.fit(model)
-    assert trainer.state.finished, f"Training failed with {trainer.state}"

assert lr_monitor.lrs, "No learning rates logged"
assert all(v is None for v in lr_monitor.last_momentum_values.values()), "Momentum should not be logged by default"
-    assert len(lr_monitor.lrs) == len(
-        trainer.lr_schedulers
-    ), "Number of learning rates logged does not match number of lr schedulers"
-    assert (
-        lr_monitor.lr_sch_names == list(lr_monitor.lrs.keys()) == ["lr-SGD"]
-    ), "Names of learning rates not set correctly"
+    assert len(lr_monitor.lrs) == len(trainer.lr_schedulers)
+    assert lr_monitor.lr_sch_names == list(lr_monitor.lrs.keys()) == ["lr-SGD"]


@pytest.mark.parametrize("opt", ["SGD", "Adam"])
@@ -79,15 +74,10 @@ def configure_optimizers(self):
callbacks=[lr_monitor],
)
trainer.fit(model)
-    assert trainer.state.finished, f"Training failed with {trainer.state}"

assert all(v is not None for v in lr_monitor.last_momentum_values.values()), "Expected momentum to be logged"
-    assert len(lr_monitor.last_momentum_values) == len(
-        trainer.lr_schedulers
-    ), "Number of momentum values logged does not match number of lr schedulers"
-    assert all(
-        k == f"lr-{opt}-momentum" for k in lr_monitor.last_momentum_values.keys()
-    ), "Names of momentum values not set correctly"
+    assert len(lr_monitor.last_momentum_values) == len(trainer.lr_schedulers)
+    assert all(k == f"lr-{opt}-momentum" for k in lr_monitor.last_momentum_values.keys())


def test_log_momentum_no_momentum_optimizer(tmpdir):
@@ -111,15 +101,10 @@ def configure_optimizers(self):
)
with pytest.warns(RuntimeWarning, match="optimizers do not have momentum."):
trainer.fit(model)
-    assert trainer.state.finished, f"Training failed with {trainer.state}"

assert all(v == 0 for v in lr_monitor.last_momentum_values.values()), "Expected momentum to be logged"
-    assert len(lr_monitor.last_momentum_values) == len(
-        trainer.lr_schedulers
-    ), "Number of momentum values logged does not match number of lr schedulers"
-    assert all(
-        k == "lr-ASGD-momentum" for k in lr_monitor.last_momentum_values.keys()
-    ), "Names of momentum values not set correctly"
+    assert len(lr_monitor.last_momentum_values) == len(trainer.lr_schedulers)
+    assert all(k == "lr-ASGD-momentum" for k in lr_monitor.last_momentum_values.keys())


def test_lr_monitor_no_lr_scheduler_single_lr(tmpdir):
@@ -138,15 +123,11 @@ def configure_optimizers(self):
default_root_dir=tmpdir, max_epochs=2, limit_val_batches=0.1, limit_train_batches=0.5, callbacks=[lr_monitor]
)

# with pytest.warns(RuntimeWarning, match="have no learning rate schedulers"):
trainer.fit(model)
-    assert trainer.state.finished, f"Training failed with {trainer.state}"

assert lr_monitor.lrs, "No learning rates logged"
-    assert len(lr_monitor.lrs) == len(
-        trainer.optimizers
-    ), "Number of learning rates logged does not match number of optimizers"
-    assert lr_monitor.lr_sch_names == ["lr-SGD"], "Names of learning rates not set correctly"
+    assert len(lr_monitor.lrs) == len(trainer.optimizers)
+    assert lr_monitor.lr_sch_names == ["lr-SGD"]


@pytest.mark.parametrize("opt", ["SGD", "Adam"])
@@ -178,15 +159,10 @@ def configure_optimizers(self):
callbacks=[lr_monitor],
)
trainer.fit(model)
-    assert trainer.state.finished, f"Training failed with {trainer.state}"

assert all(v is not None for v in lr_monitor.last_momentum_values.values()), "Expected momentum to be logged"
-    assert len(lr_monitor.last_momentum_values) == len(
-        trainer.optimizers
-    ), "Number of momentum values logged does not match number of optimizers"
-    assert all(
-        k == f"lr-{opt}-momentum" for k in lr_monitor.last_momentum_values.keys()
-    ), "Names of momentum values not set correctly"
+    assert len(lr_monitor.last_momentum_values) == len(trainer.optimizers)
+    assert all(k == f"lr-{opt}-momentum" for k in lr_monitor.last_momentum_values.keys())


def test_log_momentum_no_momentum_optimizer_no_lr_scheduler(tmpdir):
@@ -209,15 +185,10 @@ def configure_optimizers(self):
)
with pytest.warns(RuntimeWarning, match="optimizers do not have momentum."):
trainer.fit(model)
-    assert trainer.state.finished, f"Training failed with {trainer.state}"

assert all(v == 0 for v in lr_monitor.last_momentum_values.values()), "Expected momentum to be logged"
-    assert len(lr_monitor.last_momentum_values) == len(
-        trainer.optimizers
-    ), "Number of momentum values logged does not match number of optimizers"
-    assert all(
-        k == "lr-ASGD-momentum" for k in lr_monitor.last_momentum_values.keys()
-    ), "Names of momentum values not set correctly"
+    assert len(lr_monitor.last_momentum_values) == len(trainer.optimizers)
+    assert all(k == "lr-ASGD-momentum" for k in lr_monitor.last_momentum_values.keys())


def test_lr_monitor_no_logger(tmpdir):
@@ -264,22 +235,17 @@ def configure_optimizers(self):
callbacks=[lr_monitor],
)
trainer.fit(model)
-    assert trainer.state.finished, f"Training failed with {trainer.state}"

assert lr_monitor.lrs, "No learning rates logged"
-    assert len(lr_monitor.lrs) == len(
-        trainer.lr_schedulers
-    ), "Number of learning rates logged does not match number of lr schedulers"
+    assert len(lr_monitor.lrs) == len(trainer.lr_schedulers)
assert lr_monitor.lr_sch_names == ["lr-Adam", "lr-Adam-1"], "Names of learning rates not set correctly"

if logging_interval == "step":
expected_number_logged = trainer.global_step // log_every_n_steps
if logging_interval == "epoch":
expected_number_logged = trainer.max_epochs

-    assert all(
-        len(lr) == expected_number_logged for lr in lr_monitor.lrs.values()
-    ), "Length of logged learning rates do not match the expected number"
+    assert all(len(lr) == expected_number_logged for lr in lr_monitor.lrs.values())
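To make the expected-count arithmetic asserted here (and repeated in the scheduler-free variant below) concrete, with hypothetical numbers:

```python
# logging_interval="step": one value is logged every `log_every_n_steps`
# optimizer steps, so the count follows trainer.global_step.
global_step = 70  # hypothetical: 2 epochs x 35 optimizer steps
log_every_n_steps = 2
expected_number_logged = global_step // log_every_n_steps  # 35 values per key

# logging_interval="epoch": one value is logged per training epoch.
max_epochs = 2
expected_number_logged = max_epochs  # 2 values per key
```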


@pytest.mark.parametrize("logging_interval", ["step", "epoch"])
@@ -312,22 +278,17 @@ def configure_optimizers(self):
callbacks=[lr_monitor],
)
trainer.fit(model)
-    assert trainer.state.finished, f"Training failed with {trainer.state}"

assert lr_monitor.lrs, "No learning rates logged"
-    assert len(lr_monitor.lrs) == len(
-        trainer.optimizers
-    ), "Number of learning rates logged does not match number of optimizers"
+    assert len(lr_monitor.lrs) == len(trainer.optimizers)
assert lr_monitor.lr_sch_names == ["lr-Adam", "lr-Adam-1"], "Names of learning rates not set correctly"

if logging_interval == "step":
expected_number_logged = trainer.global_step // log_every_n_steps
if logging_interval == "epoch":
expected_number_logged = trainer.max_epochs

-    assert all(
-        len(lr) == expected_number_logged for lr in lr_monitor.lrs.values()
-    ), "Length of logged learning rates do not match the expected number"
+    assert all(len(lr) == expected_number_logged for lr in lr_monitor.lrs.values())


def test_lr_monitor_param_groups(tmpdir):
@@ -353,12 +314,9 @@ def configure_optimizers(self):
default_root_dir=tmpdir, max_epochs=2, limit_val_batches=0.1, limit_train_batches=0.5, callbacks=[lr_monitor]
)
trainer.fit(model, datamodule=dm)
-    assert trainer.state.finished, f"Training failed with {trainer.state}"

assert lr_monitor.lrs, "No learning rates logged"
-    assert len(lr_monitor.lrs) == 2 * len(
-        trainer.lr_schedulers
-    ), "Number of learning rates logged does not match number of param groups"
+    assert len(lr_monitor.lrs) == 2 * len(trainer.lr_schedulers)
assert lr_monitor.lr_sch_names == ["lr-Adam"]
assert list(lr_monitor.lrs.keys()) == ["lr-Adam/pg1", "lr-Adam/pg2"], "Names of learning rates not set correctly"
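The per-group keys asserted above come from a single optimizer configured with multiple parameter groups, which the monitor suffixes as `pg1`, `pg2`, and so on. A sketch of such a `configure_optimizers`; the `backbone` and `head` attributes are hypothetical:

```python
import torch


def configure_optimizers(self):
    # Two parameter groups on one Adam optimizer; the monitor logs them as
    # "lr-Adam/pg1" and "lr-Adam/pg2" while lr_sch_names stays ["lr-Adam"].
    optimizer = torch.optim.Adam(
        [
            {"params": self.backbone.parameters(), "lr": 1e-3},
            {"params": self.head.parameters(), "lr": 1e-2},
        ]
    )
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
    return [optimizer], [scheduler]
```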

