Merge pull request #843 from AntonioCarta/peval_iterations
periodic eval after eval_every iteration
AndreaCossu committed Dec 22, 2021
2 parents c3e9f33 + 572bcb5 commit b6e9239
Showing 8 changed files with 207 additions and 73 deletions.
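In short, this change moves the periodic-evaluation logic out of `BaseStrategy._periodic_eval` and into a new `PeriodicEval` plugin, and `BaseStrategy.__init__` gains a `peval_mode` argument so `eval` can be scheduled per epoch or per iteration. A minimal usage sketch of the new argument, assuming the `BaseStrategy` signature shown in the diff below (`train_exp` and `test_exp` stand in for Avalanche experiences):

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from avalanche.training.strategies import BaseStrategy

model = torch.nn.Linear(10, 2)
strategy = BaseStrategy(
    model, SGD(model.parameters(), lr=0.01), CrossEntropyLoss(),
    train_mb_size=32, train_epochs=2, eval_mb_size=32,
    eval_every=100,          # evaluate every 100 units of training
    peval_mode='iteration',  # new: 'epoch' (default) or 'iteration'
)
# eval_streams is stored in strategy._eval_streams and consumed by PeriodicEval.
strategy.train(train_exp, eval_streams=[[test_exp]])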
161 changes: 120 additions & 41 deletions avalanche/training/strategies/base_strategy.py
@@ -28,10 +28,10 @@
from typing import TYPE_CHECKING

from avalanche.training.plugins import EvaluationPlugin
from avalanche.training.plugins import StrategyPlugin

if TYPE_CHECKING:
from avalanche.core import StrategyCallbacks
from avalanche.training.plugins import StrategyPlugin


logger = logging.getLogger(__name__)
@@ -92,7 +92,7 @@ def __init__(self, model: Module, optimizer: Optimizer,
train_mb_size: int = 1, train_epochs: int = 1,
eval_mb_size: int = 1, device='cpu',
plugins: Optional[Sequence['StrategyPlugin']] = None,
evaluator=default_logger, eval_every=-1):
evaluator=default_logger, eval_every=-1, peval_mode='epoch'):
""" Init.
:param model: PyTorch model.
@@ -110,6 +110,9 @@ def __init__(self, model: Module, optimizer: Optimizer,
only at the end of the learning experience. Values >0 mean that
`eval` is called every `eval_every` epochs and at the end of the
learning experience.
:param peval_mode: one of {'epoch', 'iteration'}. Decides whether the
periodic evaluation during training should execute every
`eval_every` epochs or iterations (Default='epoch').
"""
self._criterion = criterion

@@ -141,15 +144,18 @@ def __init__(self, model: Module, optimizer: Optimizer,
self.evaluator = evaluator
""" EvaluationPlugin used for logging and metric computations. """

# Configure periodic evaluation.
assert peval_mode in {'epoch', 'iteration'}
self.eval_every = eval_every
peval = PeriodicEval(eval_every, peval_mode)
self.plugins.append(peval)

self.clock = Clock()
""" Incremental counters for strategy events. """
# WARNING: Clock needs to be the last plugin, otherwise
# counters will be wrong for plugins called after it.
self.plugins.append(self.clock)

self.eval_every = eval_every
""" Frequency of the evaluation during training. """

###################################################################
# State variables. These are updated during the train/eval loops. #
###################################################################
@@ -270,11 +276,10 @@ def train(self, experiences: Union[Experience, Sequence[Experience]],
experiences = [experiences]
if eval_streams is None:
eval_streams = [experiences]
self._eval_streams = eval_streams

self._before_training(**kwargs)

self._periodic_eval(eval_streams, do_final=False, do_initial=True)

for self.experience in experiences:
self.train_exp(self.experience, eval_streams, **kwargs)
self._after_training(**kwargs)
@@ -311,11 +316,6 @@ def train_exp(self, experience: Experience, eval_streams=None, **kwargs):
self.make_optimizer()

self._before_training_exp(**kwargs)

do_final = True
if self.eval_every > 0 and \
(self.train_epochs - 1) % self.eval_every == 0:
do_final = False

for _ in range(self.train_epochs):
self._before_training_epoch(**kwargs)
@@ -326,46 +326,42 @@ def train_exp(self, experience: Experience, eval_streams=None, **kwargs):

self.training_epoch(**kwargs)
self._after_training_epoch(**kwargs)
self._periodic_eval(eval_streams, do_final=False)

# Final evaluation
self._periodic_eval(eval_streams, do_final=do_final)
self._after_training_exp(**kwargs)

def _periodic_eval(self, eval_streams, do_final, do_initial=False):
""" Periodic eval controlled by `self.eval_every`. """
# Since we are switching from train to eval model inside the training
# loop, we need to save the training state, and restore it after the
# eval is done.
def _load_train_state(self, _prev_model_training_modes, _prev_state):
# restore train-state variables and training mode.
self.experience, self.adapted_dataset = _prev_state[:2]
self.dataloader = _prev_state[2]
self.is_training = _prev_state[3]
# restore each layer's training mode to original
for name, layer in self.model.named_modules():
try:
prev_mode = _prev_model_training_modes[name]
layer.train(mode=prev_mode)
except KeyError:
# Unknown layer, probably added during the eval
# model's adaptation. We set it to train mode.
layer.train()

def _save_train_state(self):
"""Save the training state which may be modified by the eval loop.
This currently includes: experience, adapted_dataset, dataloader,
is_training, and train/eval modes for each module.
TODO: we probably need a better way to do this.
"""
_prev_state = (
self.experience,
self.adapted_dataset,
self.dataloader,
self.is_training)

# save each layer's training mode, to restore it later
_prev_model_training_modes = {}
for name, layer in self.model.named_modules():
_prev_model_training_modes[name] = layer.training

curr_epoch = self.clock.train_exp_epochs
if (self.eval_every == 0 and (do_final or do_initial)) or \
(self.eval_every > 0 and do_initial) or \
(self.eval_every > 0 and curr_epoch % self.eval_every == 0):
# in the first case we are outside epoch loop
# in the second case we are within epoch loop
for exp in eval_streams:
self.eval(exp)

# restore train-state variables and training mode.
self.experience, self.adapted_dataset = _prev_state[:2]
self.dataloader = _prev_state[2]
self.is_training = _prev_state[3]

# restore each layer's training mode to original
for name, layer in self.model.named_modules():
prev_mode = _prev_model_training_modes[name]
layer.train(mode=prev_mode)
return _prev_model_training_modes, _prev_state

def stop_training(self):
""" Signals to stop training at the next iteration. """
@@ -390,6 +386,9 @@ def eval(self,
:return: dictionary containing last recorded value for
each metric name
"""
# eval can be called inside the train method.
# Save the shared state here to restore before returning.
train_state = self._save_train_state()
self.is_training = False
self.model.eval()

Expand All @@ -413,9 +412,10 @@ def eval(self,
self._after_eval_exp(**kwargs)

self._after_eval(**kwargs)

res = self.evaluator.get_last_metrics()

# restore previous shared state.
self._load_train_state(*train_state)
return res

def _before_training_exp(self, **kwargs):
@@ -696,4 +696,83 @@ def _warn_for_disabled_callbacks(
)


class PeriodicEval(StrategyPlugin):
"""Schedules periodic evaluation during training.
This plugin is automatically configured and added by the BaseStrategy.
"""

def __init__(self, eval_every=-1, peval_mode='epoch', do_initial=True):
"""Init.
:param eval_every: the frequency of the calls to `eval` inside the
training loop. -1 disables the evaluation. 0 means `eval` is called
only at the end of the learning experience. Values >0 mean that
`eval` is called every `eval_every` epochs and at the end of the
learning experience.
:param peval_mode: one of {'epoch', 'iteration'}. Decides whether the
periodic evaluation during training should execute every
`eval_every` epochs or iterations (Default='epoch').
:param do_initial: whether to evaluate before each `train` call.
Occasionally needed because some metrics need to know the
accuracy before training.
"""
super().__init__()
assert peval_mode in {'epoch', 'iteration'}
self.eval_every = eval_every
self.peval_mode = peval_mode
self.do_initial = do_initial and eval_every > -1
self.do_final = None
self._is_eval_updated = False

def before_training(self, strategy, **kwargs):
"""Eval before each learning experience.
Occasionally needed because some metrics need the accuracy before
training.
"""
if self.do_initial:
self._peval(strategy)

def before_training_exp(self, strategy, **kwargs):
# We evaluate at the start of each experience because train_epochs
# could change.
self.do_final = True
if self.peval_mode == 'epoch':
if self.eval_every > 0 and \
(strategy.train_epochs - 1) % self.eval_every == 0:
self.do_final = False
else: # peval_mode == 'iteration'
# We may need to fix this, but we don't have a way to know
# the total number of iterations in advance.
# Right now there may be two eval calls at the last iteration.
pass
self.do_final = self.do_final and self.eval_every > -1

def after_training_exp(self, strategy, **kwargs):
"""Final eval after a learning experience."""
if self.do_final:
self._peval(strategy)

def _peval(self, strategy):
for el in strategy._eval_streams:
strategy.eval(el)

def _maybe_peval(self, strategy, counter):
if self.eval_every > 0 and counter % self.eval_every == 0:
self._peval(strategy)

def after_training_epoch(self, strategy: 'BaseStrategy', **kwargs):
"""Periodic eval controlled by `self.eval_every` and
`self.peval_mode`."""
if self.peval_mode == 'epoch':
self._maybe_peval(strategy, strategy.clock.train_exp_epochs)

def after_training_iteration(self, strategy: 'BaseStrategy', **kwargs):
"""Periodic eval controlled by `self.eval_every` and
`self.peval_mode`."""
if self.peval_mode == 'iteration':
self._maybe_peval(strategy, strategy.clock.train_exp_iterations)


__all__ = ['BaseStrategy']
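
For context (not part of the diff), a condensed sketch of the condition `PeriodicEval._maybe_peval` checks after each epoch or iteration, using the counters maintained by the strategy's `Clock` plugin:

def should_eval(eval_every: int, counter: int) -> bool:
    # Periodic evaluation only runs when enabled (eval_every > 0) and when the
    # epoch/iteration counter is a multiple of eval_every.
    return eval_every > 0 and counter % eval_every == 0

# Example: eval_every=2 in 'epoch' mode across five epochs (counters 0..4).
print([c for c in range(5) if should_eval(2, c)])  # [0, 2, 4]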