Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Change registered names of scheduler callbacks (#2964)
Browse files Browse the repository at this point in the history
* Change registered names of scheduler callbacks

* Change file names

* Change class names

* Fix docs
  • Loading branch information
epwalsh authored and joelgrus committed Jun 18, 2019
1 parent 2a88450 commit 2a59be3
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 14 deletions.
12 changes: 6 additions & 6 deletions allennlp/tests/training/callback_trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from allennlp.training.callbacks import (
Events,
LogToTensorboard, CheckpointCallback, MovingAverageCallback, Validate, PostToUrl,
LrsCallback, MomentumSchedulerCallback, TrackMetrics, TrainSupervised, GenerateTrainingBatches
UpdateLearningRate, UpdateMomentum, TrackMetrics, TrainSupervised, GenerateTrainingBatches
)
from allennlp.training.checkpointer import Checkpointer
from allennlp.training.learning_rate_schedulers import LearningRateScheduler
Expand Down Expand Up @@ -458,7 +458,7 @@ def test_should_stop_early_with_invalid_patience(self):
def test_trainer_can_run_and_resume_with_momentum_scheduler(self):
scheduler = MomentumScheduler.from_params(
self.optimizer, Params({"type": "inverted_triangular", "cool_down": 2, "warm_up": 2}))
callbacks = self.default_callbacks() + [MomentumSchedulerCallback(scheduler)]
callbacks = self.default_callbacks() + [UpdateMomentum(scheduler)]
trainer = CallbackTrainer(model=self.model,
optimizer=self.optimizer,
num_epochs=4,
Expand All @@ -468,7 +468,7 @@ def test_trainer_can_run_and_resume_with_momentum_scheduler(self):

new_scheduler = MomentumScheduler.from_params(
self.optimizer, Params({"type": "inverted_triangular", "cool_down": 2, "warm_up": 2}))
new_callbacks = self.default_callbacks() + [MomentumSchedulerCallback(new_scheduler)]
new_callbacks = self.default_callbacks() + [UpdateMomentum(new_scheduler)]
new_trainer = CallbackTrainer(model=self.model,
optimizer=self.optimizer,
num_epochs=6,
Expand All @@ -482,7 +482,7 @@ def test_trainer_can_run_and_resume_with_momentum_scheduler(self):
def test_trainer_can_run_with_lr_scheduler(self):
lr_params = Params({"type": "reduce_on_plateau"})
lr_scheduler = LearningRateScheduler.from_params(self.optimizer, lr_params)
callbacks = self.default_callbacks() + [LrsCallback(lr_scheduler)]
callbacks = self.default_callbacks() + [UpdateLearningRate(lr_scheduler)]

trainer = CallbackTrainer(model=self.model,
optimizer=self.optimizer,
Expand All @@ -493,7 +493,7 @@ def test_trainer_can_run_with_lr_scheduler(self):
def test_trainer_can_resume_with_lr_scheduler(self):
lr_scheduler = LearningRateScheduler.from_params(
self.optimizer, Params({"type": "exponential", "gamma": 0.5}))
callbacks = self.default_callbacks() + [LrsCallback(lr_scheduler)]
callbacks = self.default_callbacks() + [UpdateLearningRate(lr_scheduler)]

trainer = CallbackTrainer(model=self.model,
optimizer=self.optimizer,
Expand All @@ -503,7 +503,7 @@ def test_trainer_can_resume_with_lr_scheduler(self):

new_lr_scheduler = LearningRateScheduler.from_params(
self.optimizer, Params({"type": "exponential", "gamma": 0.5}))
callbacks = self.default_callbacks() + [LrsCallback(new_lr_scheduler)]
callbacks = self.default_callbacks() + [UpdateLearningRate(new_lr_scheduler)]

new_trainer = CallbackTrainer(model=self.model,
optimizer=self.optimizer,
Expand Down
4 changes: 2 additions & 2 deletions allennlp/training/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from allennlp.training.callbacks.events import Events

from allennlp.training.callbacks.log_to_tensorboard import LogToTensorboard
from allennlp.training.callbacks.learning_rate_scheduler import LrsCallback
from allennlp.training.callbacks.momentum_scheduler import MomentumSchedulerCallback
from allennlp.training.callbacks.update_learning_rate import UpdateLearningRate
from allennlp.training.callbacks.update_momentum import UpdateMomentum
from allennlp.training.callbacks.checkpoint import CheckpointCallback
from allennlp.training.callbacks.moving_average import MovingAverageCallback
from allennlp.training.callbacks.validate import Validate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from allennlp.training.callback_trainer import CallbackTrainer # pylint:disable=unused-import


@Callback.register("learning_rate_scheduler")
class LrsCallback(Callback):
@Callback.register("update_learning_rate")
class UpdateLearningRate(Callback):
"""
Callback that runs the learning rate scheduler.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from allennlp.training.callback_trainer import CallbackTrainer # pylint:disable=unused-import


@Callback.register("momentum_scheduler")
class MomentumSchedulerCallback(Callback):
@Callback.register("update_momentum")
class UpdateMomentum(Callback):
"""
Callback that runs a Momentum Scheduler.
Expand Down
4 changes: 2 additions & 2 deletions doc/api/allennlp.training.callbacks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ allennlp.training.callbacks
:undoc-members:
:show-inheritance:

.. automodule:: allennlp.training.callbacks.learning_rate_scheduler
.. automodule:: allennlp.training.callbacks.update_learning_rate
:members:
:undoc-members:
:show-inheritance:
Expand All @@ -37,7 +37,7 @@ allennlp.training.callbacks
:undoc-members:
:show-inheritance:

.. automodule:: allennlp.training.callbacks.momentum_scheduler
.. automodule:: allennlp.training.callbacks.update_momentum
:members:
:undoc-members:
:show-inheritance:
Expand Down

0 comments on commit 2a59be3

Please sign in to comment.