Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Fixes recording validation metrics for learning rate schedulers that …
Browse files Browse the repository at this point in the history
…rely on it (#4959)

* Fixes recording validation metrics for learning rate schedulers that rely on it

* Test for learning rate schedulers that take metrics

* Changelog

* Make mypy happy
  • Loading branch information
dirkgr committed Feb 4, 2021
1 parent 4535f5c commit c418f84
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 11 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -14,8 +14,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Learning rate schedulers that rely on metrics from the validation set were broken in v2.0.0. This
brings that functionality back.
- Fixed a bug where the `MultiProcessDataLoader` would crash when `num_workers > 0`, `start_method = "spawn"`, `max_instances_in_memory` is not `None`, and `batches_per_epoch` is not `None`.


## [v2.0.1](https://github.com/allenai/allennlp/releases/tag/v2.0.1) - 2021-01-29

### Added
Expand Down
21 changes: 12 additions & 9 deletions allennlp/training/metric_tracker.py
Expand Up @@ -92,15 +92,7 @@ def add_metrics(self, metrics: Dict[str, float]) -> None:
"""
Record a new value of the metric and update the various things that depend on it.
"""
try:
combined_score = sum(
factor * metrics[metric_name] for factor, metric_name in self.tracked_metrics
)
except KeyError as e:
raise ConfigurationError(
f"You configured the trainer to use the {e.args[0]}"
"metric for early stopping, but the model did not produce that metric."
)
combined_score = self.combined_score(metrics)

new_best = (self._best_so_far is None) or (combined_score > self._best_so_far)

Expand Down Expand Up @@ -128,3 +120,14 @@ def should_stop_early(self) -> bool:
return False
else:
return self._epochs_with_no_improvement >= self._patience

def combined_score(self, metrics: Dict[str, float]) -> float:
    """
    Computes the combined score of the tracked metrics.

    # Parameters

    metrics : `Dict[str, float]`
        Metric values, e.g. from the most recent validation run.

    # Returns

    `float`
        The sum of `factor * metrics[metric_name]` over every
        `(factor, metric_name)` pair in `self.tracked_metrics`.

    # Raises

    `ConfigurationError`
        If one of the tracked metric names is absent from `metrics`.
    """
    try:
        return sum(
            factor * metrics[metric_name] for factor, metric_name in self.tracked_metrics
        )
    except KeyError as e:
        # Note the trailing space before the closing quote: these two literals are
        # concatenated, so without it the message reads "...use the lossmetric...".
        raise ConfigurationError(
            f"You configured the trainer to use the {e.args[0]} "
            "metric for early stopping, but the model did not produce that metric."
        ) from e
4 changes: 2 additions & 2 deletions allennlp/training/trainer.py
Expand Up @@ -949,7 +949,6 @@ def _try_train(self) -> Tuple[Dict[str, Any], int]:
logger.info("Beginning training.")

val_metrics: Dict[str, float] = {}
this_epoch_val_metric: float = 0.0
metrics: Dict[str, Any] = {}
epochs_trained = 0
training_start_time = time.time()
Expand All @@ -976,6 +975,7 @@ def _try_train(self) -> Tuple[Dict[str, Any], int]:
elif key.startswith("worker_") and key.endswith("_memory_MB"):
metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)

this_epoch_val_metric: float = 0.0
if self._validation_data_loader is not None:
with torch.no_grad():
# We have a validation set, so compute all the metrics on it.
Expand All @@ -999,8 +999,8 @@ def _try_train(self) -> Tuple[Dict[str, Any], int]:
)

# Check validation metric for early stopping
this_epoch_val_metric = self._metric_tracker.combined_score(val_metrics)
self._metric_tracker.add_metrics(val_metrics)

if self._metric_tracker.should_stop_early():
logger.info("Ran out of patience. Stopping training.")
break
Expand Down
27 changes: 27 additions & 0 deletions tests/training/trainer_test.py
Expand Up @@ -31,6 +31,7 @@
from allennlp.training.learning_rate_schedulers import ExponentialLearningRateScheduler
from allennlp.training.momentum_schedulers import MomentumScheduler
from allennlp.training.moving_average import ExponentialMovingAverage
from allennlp.training.optimizers import Optimizer


class TrainerTestBase(AllenNlpTestCase):
Expand Down Expand Up @@ -557,6 +558,32 @@ def test_trainer_can_run_with_lr_scheduler(self):
)
trainer.train()

def test_trainer_sends_metric_to_lr_scheduler(self):
    """
    Regression test: learning rate schedulers that consume a validation metric
    (e.g. reduce-on-plateau) must receive the metric value each epoch.
    """
    # Local imports keep this test's dependencies self-contained, matching the
    # original's style of importing the scheduler inside the test.
    from typing import List, Optional

    from allennlp.training.learning_rate_schedulers import ReduceOnPlateauLearningRateScheduler

    class RecordMetricLearningRateScheduler(ReduceOnPlateauLearningRateScheduler):
        """Records every metric value passed to `step()` for later inspection."""

        def __init__(self, optimizer: Optimizer) -> None:
            super().__init__(optimizer)
            self.recordings: List[Optional[float]] = []

        # `Optional[float]` (not `float = None`) keeps mypy happy with the default.
        def step(self, metric: Optional[float] = None) -> None:
            self.recordings.append(metric)
            super().step(metric)

    lr_scheduler = RecordMetricLearningRateScheduler(self.optimizer)
    trainer = GradientDescentTrainer(
        model=self.model,
        optimizer=self.optimizer,
        data_loader=self.data_loader,
        learning_rate_scheduler=lr_scheduler,
        validation_metric="-loss",
        validation_data_loader=self.validation_data_loader,
        num_epochs=2,
    )
    trainer.train()

    # Guard against a vacuously-true check: `all([])` passes even if `step()`
    # was never called, which is exactly the failure mode under test.
    assert lr_scheduler.recordings
    assert all(value != 0 for value in lr_scheduler.recordings)

def test_trainer_can_resume_with_lr_scheduler(self):
lr_scheduler = CosineWithRestarts(self.optimizer, t_initial=5)
trainer = GradientDescentTrainer(
Expand Down

0 comments on commit c418f84

Please sign in to comment.