Merge branch 'master' into vision
epwalsh committed Dec 19, 2020
2 parents 85d38ff + 6a8d425 commit 1c72a30
Showing 5 changed files with 231 additions and 1 deletion.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -47,6 +47,10 @@ dataset at every epoch) and a `MultiTaskScheduler` (for ordering the instances w

## Unreleased (1.x branch)

### Added

- Added a new learning rate scheduler: `CombinedLearningRateScheduler`. This can be used to combine different LR schedulers, applying them one after the other (see the configuration sketch below).
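
Not part of the changelog diff itself, but a minimal sketch of how the new scheduler might be constructed from a config, mirroring the unit test added in this commit; the toy model, the SGD settings, the `polynomial_decay` parameters, and `num_steps_per_epoch` are purely illustrative:

```python
import torch

from allennlp.common.params import Params
from allennlp.training.learning_rate_schedulers import LearningRateScheduler
from allennlp.training.optimizers import Optimizer

model = torch.nn.Linear(10, 10)
optimizer = Optimizer.from_params(
    model_parameters=model.named_parameters(),
    params=Params({"type": "sgd", "lr": 1.0}),
)

scheduler = LearningRateScheduler.from_params(
    Params(
        {
            "type": "combined",
            "schedulers": [
                # Use this scheduler for the first 2 epochs ...
                [2, {"type": "polynomial_decay", "warmup_steps": 10, "end_learning_rate": 0.5}],
                # ... then this one for the next 5 epochs.
                [5, {"type": "polynomial_decay", "warmup_steps": 0, "end_learning_rate": 0.1}],
            ],
        }
    ),
    optimizer=optimizer,
    num_steps_per_epoch=10,
)
```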


## [v1.3.0](https://github.com/allenai/allennlp/releases/tag/v1.3.0) - 2020-12-15

1 change: 1 addition & 0 deletions allennlp/training/learning_rate_schedulers/__init__.py
@@ -22,6 +22,7 @@
    ExponentialLearningRateScheduler,
    ReduceOnPlateauLearningRateScheduler,
)
from allennlp.training.learning_rate_schedulers.combined import CombinedLearningRateScheduler
from allennlp.training.learning_rate_schedulers.cosine import CosineWithRestarts
from allennlp.training.learning_rate_schedulers.noam import NoamLR
from allennlp.training.learning_rate_schedulers.slanted_triangular import SlantedTriangular
121 changes: 121 additions & 0 deletions allennlp/training/learning_rate_schedulers/combined.py
@@ -0,0 +1,121 @@
from typing import Dict, Any, List, Tuple, Optional

from overrides import overrides
import torch

from allennlp.common.lazy import Lazy
from allennlp.training.learning_rate_schedulers.learning_rate_scheduler import LearningRateScheduler


@LearningRateScheduler.register("combined")
class CombinedLearningRateScheduler(LearningRateScheduler):
"""
This `LearningRateScheduler` can be used to apply an arbitrary number of other schedulers
one after the other.
These schedulers are defined though the `schedulers` parameter, which takes a list
of `Tuple[int, Lazy[LearningRateScheduler]]`. The first field of the tuple, the `int`,
specifies how many epochs the corresponding scheduler will be used before the next
scheduler takes its place.
While it usually makes sense for the sum
```python
sum(n_epochs for (n_epochs, _) in schedulers)
```
to equal the total number of training epochs, it is not a requirement.
If training continues beyond the last defined scheduler, both `step()` and `step_batch()`
will be a no-op.
"""

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        schedulers: List[Tuple[int, Lazy[LearningRateScheduler]]],
        num_steps_per_epoch: Optional[int] = None,
        last_epoch: int = -1,
    ) -> None:
        super().__init__(optimizer, last_epoch=last_epoch)
        self.num_steps_per_epoch = num_steps_per_epoch
        self.schedulers = schedulers
        # This is compared against `self.last_epoch` to know when `self._current_scheduler`
        # needs updating, so it has to start out not equal to `self.last_epoch`.
        self._last_epoch_updated = -2
        self._current_scheduler: Optional[LearningRateScheduler] = None
        self._current_scheduler_first_epoch: Optional[int] = None
        # We call this here in order to initialize the current scheduler now, since some schedulers
        # modify the LR when they are initialized.
        self.current_scheduler

    @property
    def current_scheduler(self) -> Optional[LearningRateScheduler]:
        if self._last_epoch_updated != self.last_epoch:
            current_epoch = self.last_epoch + 1
            scheduler_first_epoch, scheduler_last_epoch = 0, -1
            for scheduler_epochs, lazy_scheduler in self.schedulers:
                scheduler_last_epoch += scheduler_epochs

                # Is it time for a new scheduler?
                if current_epoch == scheduler_first_epoch or (
                    self._current_scheduler_first_epoch != scheduler_first_epoch
                    and scheduler_first_epoch <= current_epoch <= scheduler_last_epoch
                ):
                    # Reset the base values of the LR to whatever they're currently at.
                    for group in self.optimizer.param_groups:
                        group[self._initial_param_group_field] = group[self.param_group_field]
                    self._current_scheduler = lazy_scheduler.construct(
                        optimizer=self.optimizer,
                        num_epochs=scheduler_epochs,
                        num_steps_per_epoch=self.num_steps_per_epoch,
                    )
                    self._current_scheduler_first_epoch = scheduler_first_epoch
                    break

                scheduler_first_epoch = scheduler_last_epoch + 1
            else:
                # If we didn't break out of the loop, then we might have trained past
                # the last defined scheduler, so we're not going to use a scheduler anymore.
                if current_epoch > scheduler_last_epoch:
                    self._current_scheduler = None

        self._last_epoch_updated = self.last_epoch
        return self._current_scheduler

    @overrides
    def state_dict(self) -> Dict[str, Any]:
        current_scheduler = self.current_scheduler
        return {
            "last_epoch": self.last_epoch,
            "num_steps_per_epoch": self.num_steps_per_epoch,
            "current_scheduler": None
            if current_scheduler is None
            else current_scheduler.state_dict(),
        }

    @overrides
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.last_epoch = state_dict["last_epoch"]
        self.num_steps_per_epoch = state_dict["num_steps_per_epoch"]
        if self.current_scheduler is not None:
            assert state_dict["current_scheduler"] is not None
            self.current_scheduler.load_state_dict(state_dict["current_scheduler"])

    @overrides
    def get_values(self):
        """
        This should never be called directly.
        """
        raise NotImplementedError

    @overrides
    def step_batch(self, batch_num_total: int = None) -> None:
        if self.current_scheduler is not None:
            self.current_scheduler.step_batch(batch_num_total)

    @overrides
    def step(self, metric: float = None) -> None:
        self.last_epoch += 1
        self.metric = metric
        if self.current_scheduler is not None:
            self.current_scheduler.step(metric)
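
As a complement to the docstring above, here is a hedged sketch of constructing the scheduler directly in Python rather than via `from_params`, showing the `List[Tuple[int, Lazy[LearningRateScheduler]]]` argument and the hand-off between schedulers. The lambda-wrapped `Lazy` constructors and the toy optimizer are assumptions for illustration (this assumes `Lazy` accepts a plain callable), not part of this commit:

```python
import torch

from allennlp.common.lazy import Lazy
from allennlp.training.learning_rate_schedulers import (
    CombinedLearningRateScheduler,
    PolynomialDecay,
)

model = torch.nn.Linear(10, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)

# Each entry is (num_epochs, Lazy[LearningRateScheduler]); `construct()` is called with
# `optimizer`, `num_epochs`, and `num_steps_per_epoch` when that scheduler's turn comes.
scheduler = CombinedLearningRateScheduler(
    optimizer,
    schedulers=[
        (2, Lazy(lambda **kw: PolynomialDecay(warmup_steps=10, end_learning_rate=0.5, **kw))),
        (5, Lazy(lambda **kw: PolynomialDecay(warmup_steps=0, end_learning_rate=0.1, **kw))),
    ],
    num_steps_per_epoch=10,
)

for epoch in range(10):
    for _ in range(10):
        scheduler.step_batch()
    # Epochs 7-9 fall past the last defined scheduler, so these calls become no-ops
    # and the LR stays at the final value of 0.1.
    scheduler.step()
```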
2 changes: 1 addition & 1 deletion allennlp/training/scheduler.py
@@ -13,7 +13,7 @@ class Scheduler:
During training using the AllenNLP `Trainer`, this is the API and calling
sequence for `step` and `step_batch`::
-       scheduler = ... # creates scheduler, calls self.step(last_epoch=-1) in __init__
+       scheduler = ... # creates scheduler
batch_num_total = 0
for epoch in range(num_epochs):
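The docstring snippet above is truncated by the diff view. For reference, a runnable sketch (assumed, not part of the diff) of the `step()`/`step_batch()` calling sequence that the `Scheduler` docstring describes; the toy model, the SGD learning rate, and the `noam` scheduler settings are illustrative stand-ins for whatever the `Trainer` is configured with:

```python
import torch

from allennlp.common.params import Params
from allennlp.training.learning_rate_schedulers import LearningRateScheduler

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = LearningRateScheduler.from_params(
    Params({"type": "noam", "model_size": 16, "warmup_steps": 5}),
    optimizer=optimizer,
)

batch_num_total = 0
num_epochs, batches_per_epoch = 2, 5
for epoch in range(num_epochs):
    for _ in range(batches_per_epoch):
        # ... forward pass, loss.backward(), and optimizer.step() would happen here ...
        batch_num_total += 1
        scheduler.step_batch(batch_num_total)
    # Called once at the end of each epoch, optionally with a validation metric.
    scheduler.step()
```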
104 changes: 104 additions & 0 deletions tests/training/learning_rate_schedulers/combined_test.py
@@ -0,0 +1,104 @@
import torch

from allennlp.common.testing import AllenNlpTestCase
from allennlp.common.params import Params
from allennlp.training.learning_rate_schedulers import (
    LearningRateScheduler,
    CombinedLearningRateScheduler,
    PolynomialDecay,
)
from allennlp.training.optimizers import Optimizer


class TestCombinedLRScheduler(AllenNlpTestCase):
    def setup_method(self):
        super().setup_method()
        self.model = torch.nn.Sequential(torch.nn.Linear(10, 10))
        self.optimizer = Optimizer.from_params(
            model_parameters=self.model.named_parameters(),
            params=Params({"type": "sgd", "lr": 1.0}),
        )

    def get_scheduler(self) -> LearningRateScheduler:
        return LearningRateScheduler.from_params(
            Params(
                {
                    "type": "combined",
                    "schedulers": [
                        [
                            2,
                            {
                                "type": "polynomial_decay",
                                "warmup_steps": 10,
                                "end_learning_rate": 0.5,
                            },
                        ],
                        [
                            5,
                            {
                                "type": "polynomial_decay",
                                "warmup_steps": 0,
                                "end_learning_rate": 0.1,
                            },
                        ],
                    ],
                }
            ),
            optimizer=self.optimizer,
            num_steps_per_epoch=10,
        )

    def test_partial_schedule(self):
        scheduler = self.get_scheduler()
        assert isinstance(scheduler, CombinedLearningRateScheduler)
        assert isinstance(scheduler._current_scheduler, PolynomialDecay)

        # This should be 0 because the PolynomialDecay scheduler initializes the LR to 0.
        assert self.optimizer.param_groups[0]["lr"] == 0.0

        epoch_end_lrs = []
        for epoch in range(10):
            if epoch > 6:
                assert scheduler._current_scheduler is None
            elif epoch >= 2:
                assert scheduler._current_scheduler is not None
                assert scheduler._current_scheduler.total_steps == 50
                assert scheduler._current_scheduler.base_values[0] == 0.5
            else:
                assert scheduler._current_scheduler is not None
                assert scheduler._current_scheduler.total_steps == 20
                assert scheduler._current_scheduler.base_values[0] == 1.0

            for step in range(10):
                scheduler.step_batch()

            scheduler.step()

            epoch_end_lrs.append(self.optimizer.param_groups[0]["lr"])

        assert epoch_end_lrs[0] == 1.0
        assert epoch_end_lrs[1] == 0.5
        assert epoch_end_lrs[6] == 0.1
        # The LR should stay at 0.1 once training continues past the last scheduler.
        assert epoch_end_lrs[9] == 0.1

    def test_load_from_checkpoint(self):
        scheduler = self.get_scheduler()

        for epoch in range(3):
            for step in range(10):
                scheduler.step_batch()
            scheduler.step()

        assert scheduler.last_epoch == 2
        assert scheduler._current_scheduler is not None
        assert scheduler._current_scheduler.total_steps == 50
        assert scheduler._current_scheduler.base_values[0] == 0.5

        state_dict = scheduler.state_dict()
        new_scheduler = self.get_scheduler()
        new_scheduler.load_state_dict(state_dict)

        assert new_scheduler.last_epoch == 2
        assert new_scheduler._current_scheduler is not None
        assert new_scheduler._current_scheduler.total_steps == 50
        assert new_scheduler._current_scheduler.base_values[0] == 0.5, state_dict
