Added NewClassesCrossEntropy criterion and automatic criterion plugin #1514

Merged 6 commits · Oct 20, 2023
Changes from 3 commits
66 changes: 63 additions & 3 deletions avalanche/training/losses.py
@@ -1,10 +1,13 @@
import copy

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import BCELoss

from avalanche.training.plugins import SupervisedPlugin
from avalanche.training.regularization import cross_entropy_with_oh_targets


class ICaRLLossPlugin(SupervisedPlugin):
@@ -161,4 +164,61 @@ def forward(self, features, labels=None, mask=None):
return loss


class MaskedCrossEntropy(SupervisedPlugin):
    """
    Masked Cross Entropy

    This criterion can be used, for instance, in class-incremental
    learning problems when no exemplars are used
    (i.e., LwF in class-incremental learning would need to use mask="new").
    """

    def __init__(self, classes=None, mask="all", reduction="mean"):
        """
        :param classes: initial set of current classes.
        :param mask: "all" (standard cross entropy over all the classes
            seen so far), "old" (cross entropy only on the old classes),
            or "new" (cross entropy only on the new classes).
        :param reduction: "mean" or "none", average or per-sample loss.
        """
super().__init__()
assert mask in ["all", "new", "old"]
if classes is not None:
self.current_classes = set(classes)
else:
self.current_classes = set()

self.old_classes = set()
self.reduction = reduction
self.mask = mask

def __call__(self, logits, targets):
oh_targets = F.one_hot(targets, num_classes=logits.shape[1])

oh_targets = oh_targets[:, self.current_mask]
logits = logits[:, self.current_mask]

return cross_entropy_with_oh_targets(
logits,
oh_targets.float(),
reduction=self.reduction,
)

    @property
    def current_mask(self):
        if self.mask == "all":
            return list(self.current_classes.union(self.old_classes))
        if self.mask == "new":
            return list(self.current_classes)
        if self.mask == "old":
            return list(self.old_classes)

Collaborator: this is incorrect. You are selecting only "seen" units, which excludes future units (e.g. DER regularizes them).

Collaborator: I think it's better to have both "seen" and "all" options (not necessarily for all losses).

def adaptation(self, new_classes):
self.old_classes = self.old_classes.union(self.current_classes)
self.current_classes = set(new_classes)

def before_training_exp(self, strategy, **kwargs):
self.adaptation(strategy.experience.classes_in_this_experience)


__all__ = ["ICaRLLossPlugin", "SCRLoss", "MaskedCrossEntropy"]
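As a usage sketch (not part of the diff; the class ids and tensor shapes below are illustrative), the criterion is driven entirely by adaptation() and current_mask:

import torch
from avalanche.training.losses import MaskedCrossEntropy

criterion = MaskedCrossEntropy(mask="new")
criterion.adaptation([0, 1])   # experience 0 introduces classes 0 and 1
criterion.adaptation([2, 3])   # experience 1 introduces classes 2 and 3

logits = torch.randn(4, 4)     # 4 samples, one output unit per class
targets = torch.tensor([2, 3, 2, 3])

# With mask="new", only the columns of the current classes {2, 3} enter
# the loss, so logits of old classes are never penalized.
loss = criterion(logits, targets)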
9 changes: 7 additions & 2 deletions avalanche/training/regularization.py
@@ -9,12 +9,17 @@
from avalanche.models import MultiTaskModule, avalanche_forward


def cross_entropy_with_oh_targets(outputs, targets, eps=1e-5, reduction="mean"):
    """Calculates cross-entropy with one-hot targets;
    targets can also be soft targets, but they must sum to 1."""
    outputs = torch.nn.functional.softmax(outputs, dim=1)
    ce = -(targets * outputs.log()).sum(1)
    if reduction == "mean":
        ce = ce.mean()
    elif reduction == "none":
        return ce
    else:
        raise NotImplementedError("reduction must be mean or none")
    return ce

Collaborator: can you use pytorch cross-entropy? It is more numerically stable than this implementation.

Collaborator (author): The problem is that I need to provide one-hot targets, which makes the masking much easier and more natural to implement. I believe the cross entropy from PyTorch does not allow that.
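For reference, a quick check (illustrative, not from the PR) that on hard one-hot targets this function agrees with torch.nn.functional.cross_entropy, which fuses log-softmax and NLL and is the more numerically stable implementation the reviewer refers to; note also that F.cross_entropy has accepted class-probability targets since PyTorch 1.10:

import torch
import torch.nn.functional as F
from avalanche.training.regularization import cross_entropy_with_oh_targets

logits = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))
oh_targets = F.one_hot(labels, num_classes=10).float()

manual = cross_entropy_with_oh_targets(logits, oh_targets)
builtin = F.cross_entropy(logits, labels)   # fused, numerically stable
assert torch.allclose(manual, builtin, atol=1e-5)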


3 changes: 3 additions & 0 deletions avalanche/training/templates/base_sgd.py
@@ -96,6 +96,9 @@ def __init__(
self._criterion = criterion
""" Criterion. """

if criterion not in self.plugins and isinstance(criterion, BasePlugin):
self.plugins.append(criterion)

self.train_epochs: int = train_epochs
""" Number of training epochs. """

28 changes: 26 additions & 2 deletions tests/training/test_losses.py
@@ -1,7 +1,8 @@
import unittest

import torch
import torch.nn as nn

from avalanche.training.losses import ICaRLLossPlugin, MaskedCrossEntropy


class TestICaRLLossPlugin(unittest.TestCase):
@@ -34,5 +35,28 @@ def test_loss(self):
assert loss3 == loss1


class TestMaskedCrossEntropy(unittest.TestCase):
def test_loss(self):
cross_entropy = nn.CrossEntropyLoss()

criterion = MaskedCrossEntropy(mask="new")
criterion.adaptation([1, 2, 3, 4])
criterion.adaptation([5, 6, 7])

mb_y = torch.tensor([5, 5, 6, 7, 6])

new_pred = torch.rand(5, 8)
new_pred_new = new_pred[:, criterion.current_mask]

loss1 = criterion(new_pred, mb_y)
loss2 = cross_entropy(new_pred_new, mb_y - 5)

criterion.mask = "all"
loss3 = criterion(new_pred, mb_y)

self.assertAlmostEqual(float(loss1), float(loss2), places=5)
self.assertNotAlmostEqual(float(loss1), float(loss3), places=5)


if __name__ == "__main__":
unittest.main()