
Supervised Contrastive Replay implementation. #1356

Merged · 22 commits · May 31, 2023
18 changes: 18 additions & 0 deletions avalanche/models/dynamic_modules.py
@@ -455,10 +455,28 @@ def eval_adaptation(self, experience: CLExperience = None):
        self.classifier = self.eval_classifier


class NormalizedTrainEvalModel(TrainEvalModel):
    """
    NormalizedTrainEvalModel.
    This module wraps together a common feature extractor and
    two classifiers: one used at training time and another
    used at test time. The classifier is switched when `self.adaptation()`
    is called.

    After the call to the feature extractor, the output is L2-normalized
    along the right-most dimension.
    """
    def forward(self, x):
        x = self.feature_extractor(x)
        x = torch.nn.functional.normalize(x, p=2, dim=-1)
        return self.classifier(x)


__all__ = [
    "DynamicModule",
    "MultiTaskModule",
    "IncrementalClassifier",
    "MultiHeadClassifier",
    "TrainEvalModel",
    "NormalizedTrainEvalModel"
]
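For readers skimming the diff, here is a minimal usage sketch of the new wrapper (illustrative only, not part of this PR; the toy feature extractor and projection head are made up, and it assumes `train_adaptation()` / `eval_adaptation()` can be called with no arguments, as the `eval_adaptation` signature above suggests):

```python
import torch
from avalanche.models import NormalizedTrainEvalModel, NCMClassifier

# toy components, for illustration only
feature_extractor = torch.nn.Sequential(torch.nn.Flatten(),
                                        torch.nn.Linear(3 * 32 * 32, 64))
projection_head = torch.nn.Linear(64, 16)  # used at training time
ncm = NCMClassifier()                      # used at eval time

model = NormalizedTrainEvalModel(
    feature_extractor=feature_extractor,
    train_classifier=projection_head,
    eval_classifier=ncm,
)

x = torch.randn(4, 3, 32, 32)
model.train_adaptation()  # assumption: switches to the projection head
out = model(x)            # features are L2-normalized before the head -> (4, 16)
```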
103 changes: 102 additions & 1 deletion avalanche/training/losses.py
@@ -1,6 +1,7 @@
import copy

import torch
from torch import nn
from avalanche.training.plugins import SupervisedPlugin
from torch.nn import BCELoss
import numpy as np
@@ -61,4 +62,104 @@ def after_training_exp(self, strategy, **kwargs):
        ).tolist()


__all__ = ["ICaRLLossPlugin"]
class SCRLoss(torch.nn.Module):
    """
    Supervised Contrastive Replay Loss as defined in Eq. 5 of
    https://arxiv.org/pdf/2103.13885.pdf.

    Author: Yonglong Tian (yonglong@mit.edu)
    Date: May 07, 2020
    Original GitHub repository: https://github.com/HobbitLong/SupContrast/
    LICENSE: BSD 2-Clause License
    """
    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """Compute the loss for the model. If both `labels` and `mask` are
        None, it degenerates to the SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        features: [bsz, n_views, f_dim]
        `n_views` is the number of crops from each image; features should
        preferably be L2-normalized along the f_dim dimension.

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if
                sample j has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """
        device = features.device

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...], '
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)

        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positives
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss


__all__ = ["ICaRLLossPlugin", "SCRLoss"]
1 change: 1 addition & 0 deletions avalanche/training/supervised/__init__.py
@@ -12,3 +12,4 @@
from .er_ace import ER_ACE, OnlineER_ACE
from .der import DER
from .l2p import LearningToPrompt
from .supervised_contrastive_replay import SCR
170 changes: 170 additions & 0 deletions avalanche/training/supervised/supervised_contrastive_replay.py
@@ -0,0 +1,170 @@
from typing import Sequence, Optional

import torch
from avalanche.training.plugins.evaluation import default_evaluator
from avalanche.training.templates import SupervisedTemplate
from torch.nn import Module
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from avalanche.core import BaseSGDPlugin
from torchvision.transforms import Compose
from torchvision.transforms import Lambda
from avalanche.training.plugins import ReplayPlugin
from avalanche.training.losses import SCRLoss
from avalanche.training.storage_policy import ClassBalancedBuffer


class SCR(SupervisedTemplate):
    """
    Supervised Contrastive Replay from https://arxiv.org/pdf/2103.13885.pdf.
    This strategy trains an encoder network in a self-supervised manner to
    cluster together examples of the same class while pushing away examples
    of different classes. It uses the Nearest Class Mean classifier on the
    embeddings produced by the encoder.

    Accuracy cannot be monitored during training (no NCM classifier).
    During training, the SCRLoss is monitored, while during eval
    the CrossEntropyLoss is monitored.
    """
    def __init__(self,
                 model: Module,
                 optimizer: Optimizer,
                 augmentations: Compose = Compose([Lambda(lambda el: el)]),
                 mem_size: int = 100,
                 temperature: float = 0.1,
                 train_mb_size: int = 1,
                 train_epochs: int = 1,
                 eval_mb_size: Optional[int] = 1,
                 device="cpu",
                 plugins: Optional[Sequence["BaseSGDPlugin"]] = None,
                 evaluator=default_evaluator,
                 eval_every=-1,
                 peval_mode="epoch"):
        """
        :param model: an Avalanche NormalizedTrainEvalModel, where the train
            classifier uses a projection network (e.g., MLP)
            while the test classifier uses an NCM Classifier.
        :param optimizer: PyTorch optimizer.
        :param augmentations: TorchVision Compose Transformations to augment
            the input minibatch. The augmented mini-batch will be concatenated
            to the original one (which includes the memory buffer).
            Note: only augmentations that can be applied to Tensors
            are supported.
        :param mem_size: replay memory size, used also at test time to
            compute class means.
        :param temperature: SCR Loss temperature.
        :param train_mb_size: mini-batch size for training. The default
            dataloader is a task-balanced dataloader that divides each
            mini-batch evenly between samples from all existing tasks in
            the dataset.
        :param train_epochs: number of training epochs.
        :param eval_mb_size: mini-batch size for eval.
        :param device: PyTorch device where the model will be allocated.
        :param plugins: (optional) list of StrategyPlugins.
        :param evaluator: (optional) instance of EvaluationPlugin for logging
            and metric computations. None to remove logging.
        :param eval_every: the frequency of the calls to `eval` inside the
            training loop. -1 disables the evaluation. 0 means `eval` is called
            only at the end of the learning experience. Values >0 mean that
            `eval` is called every `eval_every` epochs and at the end of the
            learning experience.
        :param peval_mode: one of {'epoch', 'iteration'}. Decides whether the
            periodic evaluation during training should execute every
            `eval_every` epochs or iterations (Default='epoch').
        """

        self.replay_plugin = ReplayPlugin(
            mem_size,
            storage_policy=ClassBalancedBuffer(max_size=mem_size))

        self.augmentations = augmentations
        self.temperature = temperature

        self.train_loss = SCRLoss(temperature=self.temperature)
        self.eval_loss = torch.nn.CrossEntropyLoss()

        if plugins is None:
            plugins = [self.replay_plugin]
        elif isinstance(plugins, list):
            plugins = [self.replay_plugin] + plugins
        else:
            raise ValueError("`plugins` parameter needs to be a list.")
        super().__init__(
            model,
            optimizer,
            SCRLoss(temperature=self.temperature),
            train_mb_size,
            train_epochs,
            eval_mb_size,
            device,
            plugins,
            evaluator,
            eval_every,
            peval_mode)

    def criterion(self):
        if self.is_training:
            return self.train_loss(self.mb_output, self.mb_y)
        else:
            return self.eval_loss(self.mb_output, self.mb_y)

    def _before_forward(self, **kwargs):
        """
        Concatenate together original and augmented examples.
        """
        super()._before_forward(**kwargs)
        if self.is_training:
            mb_x_augmented = self.augmentations(self.mbatch[0])
            # (batch_size*2, input_size)
            self.mbatch[0] = torch.cat([self.mbatch[0], mb_x_augmented], dim=0)

    def _after_forward(self, **kwargs):
        """
        Reshape the model output to have 2 views: one for original examples,
        one for augmented examples.
        """
        super()._after_forward(**kwargs)
        if self.is_training:
            assert self.mb_output.size(0) % 2 == 0
            original_batch_size = int(self.mb_output.size(0) / 2)
            original_examples = self.mb_output[:original_batch_size]
            augmented_examples = self.mb_output[original_batch_size:]
            # (original_batch_size, 2, output_size)
            self.mb_output = torch.stack(
                [original_examples, augmented_examples],
                dim=1)

    def _after_training_exp(self, **kwargs):
        """Update NCM means"""
        super()._after_training_exp(**kwargs)
        self.compute_class_means()

    @torch.no_grad()
    def compute_class_means(self):
        class_means = {}

        # for each class
        for dataset in self.replay_plugin.storage_policy.buffer_datasets:
            dl = DataLoader(dataset, shuffle=False,
                            batch_size=self.eval_mb_size, drop_last=False)
            num_els = 0
            # for each mini-batch in each class
            for x, y, _ in dl:
                num_els += x.size(0)
                # class-balanced buffer, label is the same across mini-batch
                label = y[0].item()
                out = self.model.feature_extractor(x.to(self.device))
                out = torch.nn.functional.normalize(out, p=2, dim=-1)
                if label in class_means:
                    class_means[label] += out.sum(0).cpu().detach().clone()
                else:
                    class_means[label] = out.sum(0).cpu().detach().clone()
            class_means[label] /= float(num_els)

        means = []
        for _, v in sorted(class_means.items()):
            means.append(v)
        self.model.eval_classifier.class_means = torch.stack(
            means,
            dim=0).T.to(self.device)
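For context, `compute_class_means` above stores a `(feature_dim, n_classes)` matrix in `eval_classifier.class_means`. A rough sketch of the nearest-class-mean decision rule this enables (illustrative only; `NCMClassifier` is not part of this diff and may score classes differently):

```python
import torch


def ncm_predict(embeddings: torch.Tensor, class_means: torch.Tensor) -> torch.Tensor:
    """Assign each embedding to the class whose mean is closest.

    embeddings: (batch, feat_dim), L2-normalized as in compute_class_means
    class_means: (feat_dim, n_classes), the layout stored above
    """
    dists = torch.cdist(embeddings, class_means.T)  # (batch, n_classes)
    return dists.argmin(dim=1)


# toy usage with made-up sizes
emb = torch.nn.functional.normalize(torch.randn(4, 64), p=2, dim=-1)
means = torch.randn(64, 3)  # e.g., 3 classes seen so far
print(ncm_predict(emb, means))  # class indices in [0, 3)
```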
71 changes: 71 additions & 0 deletions examples/supervised_contrastive_replay.py
@@ -0,0 +1,71 @@
import torchvision.transforms

from avalanche.training import SCR
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.optim import SGD

from avalanche.benchmarks.classic import SplitCIFAR10
from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics
from avalanche.logging import InteractiveLogger
from avalanche.models import SlimResNet18, \
    NormalizedTrainEvalModel, NCMClassifier
from avalanche.training.plugins import EvaluationPlugin

fixed_class_order = np.arange(10)
device = torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"
)
scenario = SplitCIFAR10(
    5,
    return_task_id=False,
    seed=0,
    fixed_class_order=fixed_class_order,
    train_transform=transforms.ToTensor(),
    eval_transform=transforms.ToTensor(),
    shuffle=True,
    class_ids_from_zero_in_each_exp=False,
)
input_size = (3, 32, 32)

nf = 20
encoding_network = SlimResNet18(10, nf=nf)
encoding_network.linear = torch.nn.Identity()
projection_network = torch.nn.Sequential(torch.nn.Linear(nf*8, 128),
                                         torch.nn.ReLU(),
                                         torch.nn.Linear(128, 128))
model = NormalizedTrainEvalModel(
    feature_extractor=encoding_network,
    train_classifier=projection_network,
    eval_classifier=NCMClassifier())
optimizer = SGD(model.parameters(), lr=0.1)
interactive_logger = InteractiveLogger()
loggers = [interactive_logger]
training_metrics = []
evaluation_metrics = [
    accuracy_metrics(stream=True),
    loss_metrics(epoch=True),
]
evaluator = EvaluationPlugin(
    *training_metrics,
    *evaluation_metrics,
    loggers=loggers,
)

cl_strategy = SCR(
    model,
    optimizer,
    augmentations=torchvision.transforms.Compose(
        [torchvision.transforms.RandomRotation(10)]),
    plugins=None,
    evaluator=evaluator,
    device=device,
    train_mb_size=128,
    eval_mb_size=64,
)
for t, experience in enumerate(scenario.train_stream):
    cl_strategy.train(experience)
    # cannot test on future experiences,
    # since NCM has no class means for unseen classes
    cl_strategy.eval(scenario.test_stream[:t+1])