Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ER-ACE criterion simplification #1313

Merged
merged 1 commit into from
Mar 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 5 additions & 44 deletions avalanche/training/regularization.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,57 +152,18 @@ class ACECriterion(RegularizationMethod):
https://openreview.net/forum?id=N8MaByOzUfb
"""

def __init__(
self,
initial_old_classes: List[int] = None,
initial_new_classes: List[int] = None,
):
"""
param: initial_old_classes: List[int]
param: initial_new_classes: List[int]
"""
self.old_classes = (
set(initial_old_classes) if
initial_old_classes is not None else set()
)
self.new_classes = (
set(initial_new_classes) if
initial_new_classes is not None else set()
)

def update(self, batch_y):
current_classes = set(torch.unique(batch_y).cpu().numpy())
inter_new = current_classes.intersection(self.new_classes)
inter_old = current_classes.intersection(self.old_classes)
if len(self.new_classes) == 0:
self.new_classes = current_classes
elif len(inter_new) == 0:
# Intersection is null, new task has arrived
self.old_classes.update(self.new_classes)
self.new_classes = current_classes
elif len(inter_new) > 0 and (
len(current_classes.union(self.new_classes)) > len(self.new_classes)
):
#
self.new_classes.update(current_classes)
elif len(inter_new) > 0 and len(inter_old) > 0:
raise ValueError(
("ACECriterion strategy cannot handle mixing",
"of same classes in different tasks")
)
def __init__(self):
pass

def __call__(self, out_in, target_in, out_buffer, target_buffer):
current_classes = torch.unique(target_in)
loss_buffer = F.cross_entropy(out_buffer, target_buffer)
oh_target_in = F.one_hot(target_in, num_classes=out_in.shape[1])
oh_target_in = oh_target_in[:, list(self.new_classes)]
oh_target_in = oh_target_in[:, current_classes]
loss_current = cross_entropy_with_oh_targets(
out_in[:, list(self.new_classes)], oh_target_in
out_in[:, current_classes], oh_target_in
)
return (loss_buffer + loss_current) / 2

@property
def all_classes(self):
return self.new_classes.union(self.old_classes)


__all__ = ["RegularizationMethod", "LearningWithoutForgetting", "ACECriterion"]
2 changes: 0 additions & 2 deletions avalanche/training/supervised/er_ace.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@ def training_epoch(self, **kwargs):

def _before_training_exp(self, **kwargs):
self.storage_policy.update(self, **kwargs)
self.ace_criterion.update(torch.tensor(self.experience.dataset.targets))
# Take all classes for ER ACE loss
buffer = self.storage_policy.buffer
if len(buffer) >= self.batch_size_mem:
Expand Down Expand Up @@ -304,7 +303,6 @@ def training_epoch(self, **kwargs):
def _before_training_exp(self, **kwargs):
# Update buffer before training exp so that we have current data in
self.storage_policy.update(self, **kwargs)
self.ace_criterion.update(torch.tensor(self.experience.dataset.targets))
buffer = self.storage_policy.buffer
if len(buffer) >= self.batch_size_mem:
self.replay_loader = cycle(
Expand Down