Merge pull request #1513 from AlbinSou/batch_obs
Some minor fixes in batch_observation and utils
AntonioCarta committed Oct 10, 2023
2 parents 581df11 + 753f9f2 commit 0515a47
Showing 2 changed files with 16 additions and 14 deletions.
7 changes: 2 additions & 5 deletions batch_observation.py

@@ -8,7 +8,7 @@
 from avalanche.models.utils import avalanche_model_adaptation
 from avalanche.training.templates.strategy_mixin_protocol import SGDStrategyProtocol
 from avalanche.models.dynamic_optimizers import reset_optimizer, update_optimizer
-from avalanche.training.utils import at_task_boundary
+from avalanche.training.utils import _at_task_boundary
 
 
 class BatchObservation(SGDStrategyProtocol):
@@ -73,10 +73,7 @@ def check_model_and_optimizer(self, reset_optimizer_state=False, **kwargs):
         if self.optimized_param_id is None:
             self.make_optimizer(reset_optimizer_state=True, **kwargs)
 
-        if at_task_boundary(self.experience):
+        if _at_task_boundary(self.experience, before=True):
             self.model = self.model_adaptation()
             self.make_optimizer(reset_optimizer_state=reset_optimizer_state, **kwargs)
-        else:
 
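Taken together, the two hunks above collapse the old if/else around at_task_boundary into a single guarded block: the model is adapted and the optimizer rebuilt only when the helper reports a task boundary before training. A commented sketch of the resulting method, assuming the surrounding strategy template from the diff (a reading aid, not the verbatim file):

    def check_model_and_optimizer(self, reset_optimizer_state=False, **kwargs):
        # No optimizer state has been created yet: build it from scratch.
        if self.optimized_param_id is None:
            self.make_optimizer(reset_optimizer_state=True, **kwargs)

        # Adapt the model and refresh the optimizer only at a task
        # boundary; before=True marks this as the before-training hook.
        if _at_task_boundary(self.experience, before=True):
            self.model = self.model_adaptation()
            self.make_optimizer(
                reset_optimizer_state=reset_optimizer_state, **kwargs
            )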
23 changes: 14 additions & 9 deletions avalanche/training/utils.py
@@ -15,18 +15,18 @@
"""
from collections import defaultdict
from typing import Dict, NamedTuple, List, Optional, Tuple, Callable, Union
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn import Module, Linear
from torch.utils.data import Dataset, DataLoader
from torch.nn import Linear, Module
from torch.utils.data import DataLoader, Dataset

from avalanche.models.batch_renorm import BatchRenorm2D
from avalanche.benchmarks import OnlineCLExperience
from avalanche.models.batch_renorm import BatchRenorm2D


def at_task_boundary(training_experience) -> bool:
def _at_task_boundary(training_experience, before=True) -> bool:
"""
Given a training experience,
returns true if the experience is at the task boundary
@@ -41,11 +41,17 @@ def at_task_boundary(training_experience) -> bool:
     - If the experience is not an online experience, returns True
+
+    :param before: If used in before_training_exp,
+        set to True, otherwise set
+        to False
     """
 
     if isinstance(training_experience, OnlineCLExperience):
         if training_experience.access_task_boundaries:
-            if training_experience.is_first_subexp:
+            if before and training_experience.is_first_subexp:
                 return True
+            elif (not before) and training_experience.is_last_subexp:
+                return True
         else:
             return True
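Assembled from the hunk above, the updated helper reads roughly as follows; the trailing return False (no boundary inside a task) is an assumption, since the diff does not show the function's tail:

    from avalanche.benchmarks import OnlineCLExperience

    def _at_task_boundary(training_experience, before=True) -> bool:
        if isinstance(training_experience, OnlineCLExperience):
            if training_experience.access_task_boundaries:
                # Boundaries are visible to the strategy: the boundary is
                # the first sub-experience (before) or the last one (after).
                if before and training_experience.is_first_subexp:
                    return True
                elif (not before) and training_experience.is_last_subexp:
                    return True
            else:
                # Boundaries are hidden: report a boundary for every
                # sub-experience.
                return True
        else:
            # Non-online experiences always sit at a task boundary.
            return True
        return False  # assumed tail: mid-task sub-experience

With before=True the helper asks "is a new task starting?", which is when check_model_and_optimizer adapts the model; with before=False it asks "is the current task ending?", the variant an after-training hook would use.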
@@ -222,7 +228,7 @@ def replace_bn_with_brn(
 ):
     for attr_str in dir(m):
         target_attr = getattr(m, attr_str)
-        if type(target_attr) == torch.nn.BatchNorm2d:
+        if isinstance(target_attr, torch.nn.BatchNorm2d):
             # print('replaced: ', name, attr_str)
             setattr(
                 m,
@@ -253,7 +259,7 @@ def change_brn_pars(
 ):
     for attr_str in dir(m):
         target_attr = getattr(m, attr_str)
-        if type(target_attr) == BatchRenorm2D:
+        if isinstance(target_attr, BatchRenorm2D):
             target_attr.momentum = torch.tensor((momentum), requires_grad=False)
             target_attr.r_max = torch.tensor(r_max, requires_grad=False)
             target_attr.d_max = torch.tensor(d_max, requires_grad=False)
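Unlike an exact type() comparison, isinstance also matches subclasses, so layers derived from BatchNorm2d (or BatchRenorm2D) are now picked up by these helpers as well. A minimal illustration; the SubBN class is hypothetical:

    import torch

    class SubBN(torch.nn.BatchNorm2d):
        """Hypothetical user-defined BatchNorm2d subclass."""

    layer = SubBN(16)
    print(type(layer) == torch.nn.BatchNorm2d)      # False: exact-type check misses subclasses
    print(isinstance(layer, torch.nn.BatchNorm2d))  # True: subclasses match too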
@@ -481,5 +487,4 @@ def __str__(self):
"examples_per_class",
"ParamData",
"cycle",
"at_task_boundary",
]
