Showing 5 changed files with 231 additions and 1 deletion.
@@ -0,0 +1,121 @@
from typing import Dict, Any, List, Tuple, Optional

from overrides import overrides
import torch

from allennlp.common.lazy import Lazy
from allennlp.training.learning_rate_schedulers.learning_rate_scheduler import LearningRateScheduler


@LearningRateScheduler.register("combined")
class CombinedLearningRateScheduler(LearningRateScheduler):
""" | ||
This `LearningRateScheduler` can be used to apply an arbitrary number of other schedulers | ||
one after the other. | ||
These schedulers are defined though the `schedulers` parameter, which takes a list | ||
of `Tuple[int, Lazy[LearningRateScheduler]]`. The first field of the tuple, the `int`, | ||
specifies how many epochs the corresponding scheduler will be used before the next | ||
scheduler takes its place. | ||
While it usually makes sense for the sum | ||
```python | ||
sum(n_epochs for (n_epochs, _) in schedulers) | ||
``` | ||
to equal the total number of training epochs, it is not a requirement. | ||
If training continues beyond the last defined scheduler, both `step()` and `step_batch()` | ||
will be a no-op. | ||
""" | ||

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        schedulers: List[Tuple[int, Lazy[LearningRateScheduler]]],
        num_steps_per_epoch: Optional[int] = None,
        last_epoch: int = -1,
    ) -> None:
        super().__init__(optimizer, last_epoch=last_epoch)
        self.num_steps_per_epoch = num_steps_per_epoch
        self.schedulers = schedulers
        # This is used to know when we need to update `self._current_scheduler`
        # by comparing it to `self.last_epoch`, and so to start with it needs to
        # not equal `self.last_epoch`.
        self._last_epoch_updated = -2
        self._current_scheduler: Optional[LearningRateScheduler] = None
        self._current_scheduler_first_epoch: Optional[int] = None
        # We call this here in order to initialize the current scheduler now, since some schedulers
        # modify the LR when they are initialized.
        self.current_scheduler

    @property
    def current_scheduler(self) -> Optional[LearningRateScheduler]:
        if self._last_epoch_updated != self.last_epoch:
            current_epoch = self.last_epoch + 1
            scheduler_first_epoch, scheduler_last_epoch = 0, -1
            for scheduler_epochs, lazy_scheduler in self.schedulers:
                scheduler_last_epoch += scheduler_epochs

                # Is it time for a new scheduler?
                if current_epoch == scheduler_first_epoch or (
                    self._current_scheduler_first_epoch != scheduler_first_epoch
                    and scheduler_first_epoch <= current_epoch <= scheduler_last_epoch
                ):
                    # Reset the base values of the LR to whatever they're currently at.
                    for group in self.optimizer.param_groups:
                        group[self._initial_param_group_field] = group[self.param_group_field]
                    self._current_scheduler = lazy_scheduler.construct(
                        optimizer=self.optimizer,
                        num_epochs=scheduler_epochs,
                        num_steps_per_epoch=self.num_steps_per_epoch,
                    )
                    self._current_scheduler_first_epoch = scheduler_first_epoch
                    break

                scheduler_first_epoch = scheduler_last_epoch + 1
            else:
                # If we didn't break out of the loop, then we might have trained past
                # the last defined scheduler, so we're not going to use a scheduler anymore.
                if current_epoch > scheduler_last_epoch:
                    self._current_scheduler = None
        self._last_epoch_updated = self.last_epoch
        return self._current_scheduler

    @overrides
    def state_dict(self) -> Dict[str, Any]:
        current_scheduler = self.current_scheduler
        return {
            "last_epoch": self.last_epoch,
            "num_steps_per_epoch": self.num_steps_per_epoch,
            "current_scheduler": None
            if current_scheduler is None
            else current_scheduler.state_dict(),
        }

    @overrides
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.last_epoch = state_dict["last_epoch"]
        self.num_steps_per_epoch = state_dict["num_steps_per_epoch"]
        if self.current_scheduler is not None:
            assert state_dict["current_scheduler"] is not None
            self.current_scheduler.load_state_dict(state_dict["current_scheduler"])

    @overrides
    def get_values(self):
        """
        This should never be called directly.
        """
        raise NotImplementedError

    @overrides
    def step_batch(self, batch_num_total: int = None) -> None:
        if self.current_scheduler is not None:
            self.current_scheduler.step_batch(batch_num_total)

    @overrides
    def step(self, metric: float = None) -> None:
        self.last_epoch += 1
        self.metric = metric
        if self.current_scheduler is not None:
            self.current_scheduler.step(metric)
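
The subtlest part of the implementation is the epoch-window arithmetic in `current_scheduler`: each wrapped scheduler owns a contiguous range of epochs determined by the cumulative epoch counts. Here is a minimal standalone sketch of just that bookkeeping (the `active_scheduler_index` helper is hypothetical, written for illustration; it is not part of this commit):

```python
from typing import List, Optional


def active_scheduler_index(epoch_spans: List[int], epoch: int) -> Optional[int]:
    """Hypothetical helper: return the index of the scheduler that covers
    `epoch`, or None once training runs past the last defined scheduler."""
    first_epoch = 0
    for index, num_epochs in enumerate(epoch_spans):
        last_epoch = first_epoch + num_epochs - 1  # inclusive upper bound
        if first_epoch <= epoch <= last_epoch:
            return index
        first_epoch = last_epoch + 1
    return None


# With the spans used in the test below ([2, 5]): epochs 0-1 use the first
# scheduler, epochs 2-6 use the second, and from epoch 7 on there is no
# active scheduler, so `step()` and `step_batch()` become no-ops.
assert active_scheduler_index([2, 5], 1) == 0
assert active_scheduler_index([2, 5], 2) == 1
assert active_scheduler_index([2, 5], 7) is None
```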
tests/training/learning_rate_schedulers/combined_test.py (104 additions, 0 deletions)
@@ -0,0 +1,104 @@
import torch

from allennlp.common.testing import AllenNlpTestCase
from allennlp.common.params import Params
from allennlp.training.learning_rate_schedulers import (
    LearningRateScheduler,
    CombinedLearningRateScheduler,
    PolynomialDecay,
)
from allennlp.training.optimizers import Optimizer


class TestCombinedLRScheduler(AllenNlpTestCase):
    def setup_method(self):
        super().setup_method()
        self.model = torch.nn.Sequential(torch.nn.Linear(10, 10))
        self.optimizer = Optimizer.from_params(
            model_parameters=self.model.named_parameters(),
            params=Params({"type": "sgd", "lr": 1.0}),
        )

    def get_scheduler(self) -> LearningRateScheduler:
        return LearningRateScheduler.from_params(
            Params(
                {
                    "type": "combined",
                    "schedulers": [
                        [
                            2,
                            {
                                "type": "polynomial_decay",
                                "warmup_steps": 10,
                                "end_learning_rate": 0.5,
                            },
                        ],
                        [
                            5,
                            {
                                "type": "polynomial_decay",
                                "warmup_steps": 0,
                                "end_learning_rate": 0.1,
                            },
                        ],
                    ],
                }
            ),
            optimizer=self.optimizer,
            num_steps_per_epoch=10,
        )

    def test_partial_schedule(self):
        scheduler = self.get_scheduler()
        assert isinstance(scheduler, CombinedLearningRateScheduler)
        assert isinstance(scheduler._current_scheduler, PolynomialDecay)

        # This should be 0 because the PolynomialDecay scheduler initializes the LR to 0.
        assert self.optimizer.param_groups[0]["lr"] == 0.0

        epoch_end_lrs = []
        for epoch in range(10):
            if epoch > 6:
                assert scheduler._current_scheduler is None
            elif epoch >= 2:
                assert scheduler._current_scheduler is not None
                assert scheduler._current_scheduler.total_steps == 50
                assert scheduler._current_scheduler.base_values[0] == 0.5
            else:
                assert scheduler._current_scheduler is not None
                assert scheduler._current_scheduler.total_steps == 20
                assert scheduler._current_scheduler.base_values[0] == 1.0

            for step in range(10):
                scheduler.step_batch()

            scheduler.step()

            epoch_end_lrs.append(self.optimizer.param_groups[0]["lr"])

        assert epoch_end_lrs[0] == 1.0
        assert epoch_end_lrs[1] == 0.5
        assert epoch_end_lrs[6] == 0.1

    def test_load_from_checkpoint(self):
        scheduler = self.get_scheduler()

        for epoch in range(3):
            for step in range(10):
                scheduler.step_batch()
            scheduler.step()

        assert scheduler.last_epoch == 2
        assert scheduler._current_scheduler is not None
        assert scheduler._current_scheduler.total_steps == 50
        assert scheduler._current_scheduler.base_values[0] == 0.5

        state_dict = scheduler.state_dict()
        new_scheduler = self.get_scheduler()
        new_scheduler.load_state_dict(state_dict)

        assert new_scheduler.last_epoch == 2
        assert new_scheduler._current_scheduler is not None
        assert new_scheduler._current_scheduler.total_steps == 50
        assert new_scheduler._current_scheduler.base_values[0] == 0.5, state_dict
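
For reference, the checkpoint state round-tripped by `test_load_from_checkpoint` has a simple shape. The following is a sketch based on the `state_dict()` method above; the nested dict is the active scheduler's own `state_dict()`, whose exact contents depend on the wrapped `PolynomialDecay` and are elided here:

```python
# Sketch of what `scheduler.state_dict()` returns after the three epochs run
# in the test above (last_epoch == 2, num_steps_per_epoch == 10).
state = {
    "last_epoch": 2,
    "num_steps_per_epoch": 10,
    "current_scheduler": {"...": "..."},  # PolynomialDecay state (elided)
}
```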