LR scheduler bug (#2905)
Fixes #2895 
A couple of things that I would like to discuss:

- Currently, the design makes it necessary to specify the mode when using the reduce_on_plateau scheduler directly, unless the scheduler is configured through a trainer, in which case the mode is set automatically from the validation metric (see the configuration sketch below).
- If the metric and the mode do not match, the code currently emits a logger.warning instead of raising an exception. I can change that to an exception.
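
To make the two configurations concrete, here is a minimal sketch (not part of the diff). The import paths and the model/optimizer setup are assumptions modeled on the tests in this PR, so treat it as illustrative rather than canonical:

    import torch
    from allennlp.common import Params
    from allennlp.training.learning_rate_schedulers import LearningRateScheduler
    from allennlp.training.optimizers import Optimizer

    model = torch.nn.Sequential(torch.nn.Linear(10, 10))
    optimizer = Optimizer.from_params(model.named_parameters(), Params({"type": "adam"}))

    # 1. Standalone use: the mode now has to be given explicitly,
    #    otherwise from_params raises a ConfigurationError.
    scheduler = LearningRateScheduler.from_params(
            optimizer, Params({"type": "reduce_on_plateau", "mode": "min"}))

    # 2. Through a trainer config: the mode can be omitted because it is filled in
    #    from the sign of the validation metric ("-" -> "min", "+" -> "max").
    trainer_params = Params({
            "validation_metric": "-loss",
            "learning_rate_scheduler": {"type": "reduce_on_plateau"},
            "optimizer": {"type": "adam", "lr": 0.01}})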
codedecde authored and brendan-ai2 committed Jul 12, 2019
1 parent 5c64f9d commit 083f343
Showing 7 changed files with 122 additions and 12 deletions.
2 changes: 1 addition & 1 deletion allennlp/tests/training/callback_trainer_test.py
@@ -588,7 +588,7 @@ def test_trainer_can_run_and_resume_with_momentum_scheduler(self):
         new_trainer.train()

     def test_trainer_can_run_with_lr_scheduler(self):
-        lr_params = Params({"type": "reduce_on_plateau"})
+        lr_params = Params({"type": "reduce_on_plateau", "mode": "min"})
         lr_scheduler = LearningRateScheduler.from_params(self.optimizer, lr_params)
         callbacks = self.default_callbacks() + [UpdateLearningRate(lr_scheduler)]

@@ -15,17 +15,26 @@ def setUp(self):
         super().setUp()
         self.model = torch.nn.Sequential(torch.nn.Linear(10, 10))

-    def test_reduce_on_plateau_error_throw_when_no_metrics_exist(self):
+    def test_reduce_on_plateau_error_throw_when_mode_not_specified(self):
         with self.assertRaises(ConfigurationError) as context:
             LearningRateScheduler.from_params(Optimizer.from_params(self.model.named_parameters(),
                                                                      Params({"type": "adam"})),
                                               Params({"type": "reduce_on_plateau"})).step(None, None)
+        assert "ReduceLROnPlateau requires a mode to be specified" in str(context.exception)
+
+    def test_reduce_on_plateau_error_throw_when_no_metrics_exist(self):
+        with self.assertRaises(ConfigurationError) as context:
+            LearningRateScheduler.from_params(Optimizer.from_params(self.model.named_parameters(),
+                                                                     Params({"type": "adam"})),
+                                              Params({"type": "reduce_on_plateau",
+                                                      "mode": "min"})).step(None, None)
         assert "learning rate scheduler requires a validation metric" in str(context.exception)

     def test_reduce_on_plateau_works_when_metrics_exist(self):
         LearningRateScheduler.from_params(Optimizer.from_params(self.model.named_parameters(),
                                                                  Params({"type": "adam"})),
-                                          Params({"type": "reduce_on_plateau"})).step(10, None)
+                                          Params({"type": "reduce_on_plateau",
+                                                  "mode": "max"})).step(10, None)

     def test_no_metric_wrapper_can_support_none_for_metrics(self):
         lrs = LearningRateScheduler.from_params(Optimizer.from_params(self.model.named_parameters(),
62 changes: 61 additions & 1 deletion allennlp/tests/training/trainer_test.py
@@ -383,7 +383,7 @@ def test_trainer_can_run_and_resume_with_momentum_scheduler(self):
         new_trainer.train()

     def test_trainer_can_run_with_lr_scheduler(self):
-        lr_params = Params({"type": "reduce_on_plateau"})
+        lr_params = Params({"type": "reduce_on_plateau", "mode": "min"})
         lr_scheduler = LearningRateScheduler.from_params(self.optimizer, lr_params)
         trainer = Trainer(model=self.model,
                           optimizer=self.optimizer,
@@ -395,6 +395,66 @@ def test_trainer_can_run_with_lr_scheduler(self):
                           num_epochs=2)
         trainer.train()

+    def test_reduce_on_plateau_and_metric_agree(self):
+        # pylint: disable=protected-access
+        for metric in ["+acc", "-loss"]:
+            trainer_params = Params({
+                    "validation_metric": metric,
+                    "learning_rate_scheduler": {
+                            "type": "reduce_on_plateau"
+                    },
+                    "optimizer": {"type": "adam", "lr": 0.01}})
+            trainer = Trainer.from_params(model=self.model,
+                                          serialization_dir=self.TEST_DIR,
+                                          iterator=self.iterator,
+                                          train_data=self.instances,
+                                          validation_data=self.instances,
+                                          params=trainer_params)
+            if metric[0] == "+":
+                correct_mode = "max"
+                assert trainer._learning_rate_scheduler.lr_scheduler.mode == correct_mode
+            else:
+                correct_mode = "min"
+                assert trainer._learning_rate_scheduler.lr_scheduler.mode == correct_mode
+
+    def test_mode_specified_in_reduce_on_plateau(self):
+        # pylint: disable=protected-access
+        for mode, metric in [("min", "-custom"), ("max", "+custom")]:
+            trainer_params = Params({
+                    "validation_metric": metric,
+                    "learning_rate_scheduler": {
+                            "type": "reduce_on_plateau",
+                            "mode": mode
+                    },
+                    "optimizer": {"type": "adam", "lr": 0.01}})
+            trainer = Trainer.from_params(model=self.model,
+                                          serialization_dir=self.TEST_DIR,
+                                          iterator=self.iterator,
+                                          train_data=self.instances,
+                                          validation_data=self.instances,
+                                          params=trainer_params)
+            assert trainer._learning_rate_scheduler.lr_scheduler.mode == mode
+
+    def test_mode_doesnt_agree_with_metric(self):
+        # pylint: disable=protected-access
+        for mode, metric in [("max", "-custom"), ("min", "+custom")]:
+            trainer_params = Params({
+                    "validation_metric": metric,
+                    "learning_rate_scheduler": {
+                            "type": "reduce_on_plateau",
+                            "mode": mode
+                    },
+                    "optimizer": {"type": "adam", "lr": 0.01}})
+            with self.assertLogs(logger="allennlp.training.util", level="WARNING"):
+                # we warn when the metric and the mode don't agree
+                trainer = Trainer.from_params(model=self.model,
+                                              serialization_dir=self.TEST_DIR,
+                                              iterator=self.iterator,
+                                              train_data=self.instances,
+                                              validation_data=self.instances,
+                                              params=trainer_params)
+            assert trainer._learning_rate_scheduler.lr_scheduler.mode == mode
+
     def test_trainer_can_resume_with_lr_scheduler(self):
         lr_scheduler = LearningRateScheduler.from_params(
             self.optimizer, Params({"type": "exponential", "gamma": 0.5}))
@@ -1,4 +1,4 @@
-from typing import Dict, Any
+from typing import Dict, Any, Optional

 from overrides import overrides
 import torch
@@ -13,8 +13,9 @@ class LearningRateScheduler(Scheduler, Registrable):

     def __init__(self,
                  optimizer: torch.optim.Optimizer,
-                 last_epoch: int = -1) -> None:
-        super().__init__(optimizer, "lr", last_epoch)
+                 last_epoch: int = -1,
+                 mode: Optional[str] = None) -> None:
+        super().__init__(optimizer, "lr", last_epoch, mode)

     def get_values(self) -> None:
         raise NotImplementedError
@@ -24,7 +25,13 @@ def get_values(self) -> None:
     def from_params(cls, optimizer: torch.optim.Optimizer, params: Params):  # type: ignore
         # pylint: disable=arguments-differ
         scheduler_type = params.pop_choice("type", LearningRateScheduler.list_available())
-        scheduler = LearningRateScheduler.by_name(scheduler_type)(optimizer, **params.as_dict())  # type: ignore
+        if scheduler_type == "reduce_on_plateau" and "mode" not in params:
+            raise ConfigurationError("ReduceLROnPlateau requires a mode to be specified."
+                                     " This ensures that there are no accidental side effects like"
+                                     " mode not being faithful to the metric being tracked")
+
+        scheduler = LearningRateScheduler.by_name(scheduler_type)(optimizer=optimizer,
+                                                                  **params.as_dict())  # type: ignore
         if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
             return _PyTorchLearningRateSchedulerWithMetricsWrapper(scheduler)
         elif isinstance(scheduler, torch.optim.lr_scheduler._LRScheduler):  # pylint: disable=protected-access
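
For illustration, a short sketch of what the new guard in from_params means for a caller; it mirrors test_reduce_on_plateau_error_throw_when_mode_not_specified above, and the import paths and setup are assumptions rather than part of the diff:

    import torch
    from allennlp.common import Params
    from allennlp.common.checks import ConfigurationError
    from allennlp.training.learning_rate_schedulers import LearningRateScheduler
    from allennlp.training.optimizers import Optimizer

    model = torch.nn.Sequential(torch.nn.Linear(10, 10))
    optimizer = Optimizer.from_params(model.named_parameters(), Params({"type": "adam"}))
    try:
        # "mode" is missing, so from_params now raises instead of silently defaulting.
        LearningRateScheduler.from_params(optimizer, Params({"type": "reduce_on_plateau"}))
    except ConfigurationError as error:
        print(error)  # "ReduceLROnPlateau requires a mode to be specified. ..."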
6 changes: 4 additions & 2 deletions allennlp/training/scheduler.py
@@ -1,4 +1,4 @@
-from typing import Dict, Any
+from typing import Dict, Any, Optional

 import torch

@@ -29,10 +29,12 @@ class Scheduler:
     def __init__(self,
                  optimizer: torch.optim.Optimizer,
                  param_group_field: str,
-                 last_epoch: int = -1) -> None:
+                 last_epoch: int = -1,
+                 mode: Optional[str] = None) -> None:
         self.optimizer = optimizer
         self.param_group_field = param_group_field
         self._initial_param_group_field = f"initial_{param_group_field}"
+        self.mode = mode
         if last_epoch == -1:
             for i, group in enumerate(self.optimizer.param_groups):
                 if param_group_field not in group:
16 changes: 14 additions & 2 deletions allennlp/training/trainer.py
@@ -669,8 +669,20 @@ def from_params(cls,  # type: ignore
         cuda_device = parse_cuda_device(params.pop("cuda_device", -1))
         grad_norm = params.pop_float("grad_norm", None)
         grad_clipping = params.pop_float("grad_clipping", None)
-        lr_scheduler_params = params.pop("learning_rate_scheduler", None)
-        momentum_scheduler_params = params.pop("momentum_scheduler", None)

+        if validation_metric[0] == "-":
+            should_decrease = True
+        elif validation_metric[0] == "+":
+            should_decrease = False
+        else:
+            raise ConfigurationError("metric_name must start with + or -")
+
+        lr_scheduler_params = training_util.update_scheduler_params(
+                params.pop("learning_rate_scheduler", None),
+                should_decrease)
+        momentum_scheduler_params = training_util.update_scheduler_params(
+                params.pop("momentum_scheduler", None),
+                should_decrease)
+
         if isinstance(cuda_device, list):
             model_device = cuda_device[0]
20 changes: 20 additions & 0 deletions allennlp/training/util.py
@@ -428,6 +428,7 @@ def evaluate(model: Model,

     return final_metrics

+
 def description_from_metrics(metrics: Dict[str, float]) -> str:
     if (not HasBeenWarned.tqdm_ignores_underscores and
             any(metric_name.startswith("_") for metric_name in metrics)):
@@ -437,3 +438,22 @@ def description_from_metrics(metrics: Dict[str, float]) -> str:
     return ', '.join(["%s: %.4f" % (name, value)
                       for name, value in
                       metrics.items() if not name.startswith("_")]) + " ||"
+
+
+def update_scheduler_params(params: Optional[Params], should_decrease: bool) -> Params:
+    """Sets the ``mode`` in the scheduler params, if it is not already specified,
+    based on whether the validation metric should decrease or increase.
+    """
+    def _is_faithful(mode: str, should_decrease: bool) -> bool:
+        if mode not in ["min", "max"]:
+            raise ConfigurationError("mode should be min or max")
+        return bool((mode == "min" and should_decrease) or (mode == "max" and not should_decrease))
+
+    if params is not None:
+        if "mode" in params:
+            if not _is_faithful(params.get("mode"), should_decrease):
+                logger.warning("The mode for the scheduler and the metrics are not "
+                               "faithful to each other.")
+        else:
+            params["mode"] = "min" if should_decrease else "max"
+    return params
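
As a quick reference, a small usage sketch of this helper; the behavior shown follows the code above, and the values mirror the new trainer tests (Params comes from allennlp.common):

    from allennlp.common import Params
    from allennlp.training.util import update_scheduler_params

    # No mode given: it is filled in from the direction of the validation metric.
    params = update_scheduler_params(Params({"type": "reduce_on_plateau"}), should_decrease=True)
    assert params["mode"] == "min"

    # Mode given but disagreeing with the metric: the value is kept and a warning is logged.
    params = update_scheduler_params(Params({"type": "reduce_on_plateau", "mode": "max"}),
                                     should_decrease=True)
    assert params["mode"] == "max"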
