Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Consistent metric tracker #4928

Merged
merged 4 commits into from
Jan 25, 2021
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
20 changes: 11 additions & 9 deletions allennlp/training/metric_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ def __init__(
metric_name: Union[str, List[str]],
patience: Optional[int] = None,
) -> None:
self._best_so_far: Optional[float] = None
self._patience = patience
self._best_so_far: Optional[float] = None
self._epochs_with_no_improvement = 0
self._is_best_so_far = True
self.best_epoch_metrics: Dict[str, float] = {}
self._epoch_number = 0
self.best_epoch: Optional[int] = None
self.best_epoch_metrics: Dict[str, float] = {}

if isinstance(metric_name, str):
metric_name = [metric_name]
Expand All @@ -59,32 +59,34 @@ def clear(self) -> None:
self._is_best_so_far = True
self._epoch_number = 0
self.best_epoch = None
self.best_epoch_metrics.clear()

def state_dict(self) -> Dict[str, Any]:
"""
A `Trainer` can use this to serialize the state of the metric tracker.
"""
return {
"best_so_far": self._best_so_far,
"patience": self._patience,
"epochs_with_no_improvement": self._epochs_with_no_improvement,
"is_best_so_far": self._is_best_so_far,
"best_epoch_metrics": self.best_epoch_metrics,
"epoch_number": self._epoch_number,
"best_epoch": self.best_epoch,
"best_epoch_metrics": self.best_epoch_metrics,
}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""
A `Trainer` can use this to hydrate a metric tracker from a serialized state.
"""
self._best_so_far = state_dict["best_so_far"]
self._patience = state_dict["patience"]
self._epochs_with_no_improvement = state_dict["epochs_with_no_improvement"]
self._is_best_so_far = state_dict["is_best_so_far"]
self.best_epoch_metrics = state_dict["best_epoch_metrics"]
self._epoch_number = state_dict["epoch_number"]
self.best_epoch = state_dict["best_epoch"]

# Even though we don't promise backwards compatibility for the --recover flag,
# it's particularly easy and harmless to provide it here, so we do it.
self.best_epoch_metrics = state_dict.get("best_epoch_metrics", {})

def add_metrics(self, metrics: Dict[str, float]) -> None:
"""
Expand All @@ -103,13 +105,13 @@ def add_metrics(self, metrics: Dict[str, float]) -> None:
new_best = (self._best_so_far is None) or (combined_score > self._best_so_far)

if new_best:
self.best_epoch = self._epoch_number
self._is_best_so_far = True
self._best_so_far = combined_score
self._epochs_with_no_improvement = 0
self._is_best_so_far = True
self.best_epoch = self._epoch_number
else:
self._is_best_so_far = False
self._epochs_with_no_improvement += 1
self._is_best_so_far = False
self._epoch_number += 1

def is_best_so_far(self) -> bool:
Expand Down