Skip to content

Commit

Permalink
Merge pull request #925 from AntonioCarta/dpmixture
Browse files Browse the repository at this point in the history
fix EWC with dynamic models
  • Loading branch information
AntonioCarta committed Mar 12, 2022
2 parents 3ab23fc + fc36b33 commit 8966e00
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 13 deletions.
12 changes: 7 additions & 5 deletions avalanche/evaluation/metric_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,12 +248,14 @@ def phase_and_task(strategy: "SupervisedTemplate") -> Tuple[str, int]:
:return: The current phase name as either "Train" or "Task" and the
associated task label.
"""

task = strategy.experience.task_labels
if len(task) > 1:
task = None # task labels per patterns
if hasattr(strategy.experience, 'task_labels'):
task = strategy.experience.task_labels
if len(task) > 1:
task = None # task labels per patterns
else:
task = task[0]
else:
task = task[0]
task = None

if strategy.is_eval:
return EVAL, task
Expand Down
6 changes: 5 additions & 1 deletion avalanche/evaluation/metrics/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,11 @@ def result(self, strategy=None) -> float:

def update(self, strategy):
# task labels defined for each experience
task_labels = strategy.experience.task_labels
if hasattr(strategy.experience, 'task_labels'):
task_labels = strategy.experience.task_labels
else:
task_labels = [0] # add fixed task label if not available.

if len(task_labels) > 1:
# task labels defined for each pattern
task_labels = strategy.mb_task_id
Expand Down
6 changes: 5 additions & 1 deletion avalanche/evaluation/metrics/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,11 @@ def result(self, strategy=None) -> float:

def update(self, strategy):
# task labels defined for each experience
task_labels = strategy.experience.task_labels
if hasattr(strategy.experience, 'task_labels'):
task_labels = strategy.experience.task_labels
else:
task_labels = [0] # add fixed task label if not available.

if len(task_labels) > 1:
# task labels defined for each pattern
# fall back to single task case
Expand Down
8 changes: 8 additions & 0 deletions avalanche/training/plugins/ewc.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ def before_backward(self, strategy, **kwargs):
self.saved_params[experience],
self.importances[experience],
):
# dynamic models may add new units
# new units are ignored by the regularization
n_units = saved_param.shape[0]
cur_param = saved_param[:n_units]
penalty += (imp * (cur_param - saved_param).pow(2)).sum()
elif self.mode == "online":
prev_exp = exp_counter - 1
Expand All @@ -93,6 +97,10 @@ def before_backward(self, strategy, **kwargs):
self.saved_params[prev_exp],
self.importances[prev_exp],
):
# dynamic models may add new units
# new units are ignored by the regularization
n_units = saved_param.shape[0]
cur_param = saved_param[:n_units]
penalty += (imp * (cur_param - saved_param).pow(2)).sum()
else:
raise ValueError("Wrong EWC mode.")
Expand Down
5 changes: 5 additions & 0 deletions avalanche/training/templates/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ def __init__(
self.current_eval_stream = None
""" Current evaluation stream. """

@property
def is_eval(self):
"""True if the strategy is in evaluation mode."""
return not self.is_training

def train(
self,
experiences: Union[Experience, Sequence[Experience]],
Expand Down
5 changes: 4 additions & 1 deletion avalanche/training/templates/base_sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,13 @@ def _train_exp(self, experience: Experience, eval_streams=None, **kwargs):
self.training_epoch(**kwargs)
self._after_training_epoch(**kwargs)

def _eval_exp(self, **kwargs):
def _before_eval_exp(self, **kwargs):
self.make_eval_dataloader(**kwargs)
# Model Adaptation (e.g. freeze/add new units)
self.model = self.model_adaptation()
super()._before_eval_exp(**kwargs)

def _eval_exp(self, **kwargs):
self.eval_epoch(**kwargs)

def make_train_dataloader(self, **kwargs):
Expand Down
5 changes: 0 additions & 5 deletions avalanche/training/templates/supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,6 @@ def __init__(
use :attr:`.BaseTemplate.experience`.
"""

@property
def is_eval(self):
"""True if the strategy is in evaluation mode."""
return not self.is_training

@property
def mb_x(self):
"""Current mini-batch input."""
Expand Down

0 comments on commit 8966e00

Please sign in to comment.