From bcd6265bf589df70a2bcdceb92ea7c7ca83d00bf Mon Sep 17 00:00:00 2001 From: AlbinSou Date: Tue, 31 Oct 2023 12:29:26 +0100 Subject: [PATCH 1/4] added a few components used by CIL exemplar free methods --- avalanche/models/cosine_layer.py | 107 +++++++++ avalanche/models/fecam.py | 240 +++++++++++++++++++++ avalanche/training/plugins/update_fecam.py | 179 +++++++++++++++ avalanche/training/plugins/update_ncm.py | 119 ++++++++++ 4 files changed, 645 insertions(+) create mode 100644 avalanche/models/cosine_layer.py create mode 100644 avalanche/models/fecam.py create mode 100644 avalanche/training/plugins/update_fecam.py create mode 100644 avalanche/training/plugins/update_ncm.py diff --git a/avalanche/models/cosine_layer.py b/avalanche/models/cosine_layer.py new file mode 100644 index 000000000..c8b0939e9 --- /dev/null +++ b/avalanche/models/cosine_layer.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from avalanche.models import DynamicModule + + +""" +Implementation of Cosine layer taken and modified from https://github.com/G-U-N/PyCIL +""" + + +class CosineLinear(nn.Module): + def __init__(self, in_features, out_features, sigma=True): + super(CosineLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter(torch.Tensor(self.out_features, in_features)) + if sigma: + self.sigma = nn.Parameter(torch.Tensor(1)) + else: + self.register_parameter("sigma", None) + self.reset_parameters() + + def reset_parameters(self): + stdv = 1.0 / math.sqrt(self.weight.size(1)) + self.weight.data.uniform_(-stdv, stdv) + if self.sigma is not None: + self.sigma.data.fill_(1) + + def forward(self, input): + out = F.linear( + F.normalize(input, p=2, dim=1), F.normalize(self.weight, p=2, dim=1) + ) + if self.sigma is not None: + out = self.sigma * out + + return out + + +class SplitCosineLinear(nn.Module): + """ + This class keeps two Cosine Linear layers, without sigma, and handles the sigma parameter + that is common for the two of them. 
One CosineLinear is for the old classes and the other + one is for the new classes + """ + + def __init__(self, in_features, out_features1, out_features2, sigma=True): + super(SplitCosineLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features1 + out_features2 + self.fc1 = CosineLinear(in_features, out_features1, False) + self.fc2 = CosineLinear(in_features, out_features2, False) + if sigma: + self.sigma = nn.Parameter(torch.Tensor(1)) + self.sigma.data.fill_(1) + else: + self.register_parameter("sigma", None) + + def forward(self, x): + out1 = self.fc1(x) + out2 = self.fc2(x) + + out = torch.cat((out1, out2), dim=1) + + if self.sigma is not None: + out = self.sigma * out + + return out + + +class CosineIncrementalClassifier(DynamicModule): + # WARNING Maybe does not work with initial evaluation + def __init__(self, in_features, num_classes): + super().__init__() + self.fc = CosineLinear(in_features, num_classes) + self.num_current_classes = num_classes + self.feature_dim = in_features + + def adaptation(self, experience): + max_class = torch.max(experience.classes_in_this_experience)[0] + if max_class <= self.num_current_classes: + # Do not adapt + return + self.num_current_classes = max_class + fc = self.generate_fc(self.feature_dim, max_class + 1) + if experience.current_experience == 1: + fc.fc1.weight.data = self.fc.weight.data + fc.sigma.data = self.fc.sigma.data + else: + prev_out_features1 = self.fc.fc1.out_features + fc.fc1.weight.data[:prev_out_features1] = self.fc.fc1.weight.data + fc.fc1.weight.data[prev_out_features1:] = self.fc.fc2.weight.data + fc.sigma.data = self.fc.sigma.data + del self.fc + self.fc = fc + + def forward(self, x): + return self.fc(x) + + def generate_fc(self, in_dim, out_dim): + fc = SplitCosineLinear( + in_dim, self.fc.out_features, out_dim - self.fc.out_features + ) + return fc diff --git a/avalanche/models/fecam.py b/avalanche/models/fecam.py new file mode 100644 index 000000000..46ce9d0dc --- /dev/null +++ b/avalanche/models/fecam.py @@ -0,0 +1,240 @@ +import copy +from typing import Dict + +import numpy as np +import torch +import torch.nn.functional as F +import tqdm +from torch import Tensor, nn + +from avalanche.benchmarks.utils import concat_datasets +from avalanche.evaluation.metric_results import MetricValue +from avalanche.models import DynamicModule +from avalanche.training.plugins import SupervisedPlugin +from avalanche.training.storage_policy import ClassBalancedBuffer +from avalanche.training.templates import SupervisedTemplate + + +class FeCAMClassifier(DynamicModule): + """ + FeCAMClassifier + + Similar to NCM but uses malahanobis distance instead of l2 distance + + This approach has been proposed for continual learning in + "FeCAM: Exploiting the Heterogeneity of Class Distributions + in Exemplar-Free Continual Learning" Goswami et. al. 
+ (Neurips 2023) + + This requires the storage of full per-class covariance matrices + """ + + def __init__( + self, + tukey=True, + shrinkage=True, + shrink1: float = 1.0, + shrink2: float = 1.0, + tukey1: float = 0.5, + covnorm=True, + ): + """ + :param tukey: whether to use the tukey transforms + (help get the distribution closer + to multivariate gaussian) + :param shrinkage: whether to shrink the covariance matrices + :param shrink1: + :param shrink2: + :param tukey1: power in tukey transforms + :param covnorm: whether to normalize the covariance matrix + """ + super().__init__() + self.class_means_dict = {} + self.class_cov_dict = {} + + self.tukey = tukey + self.shrinkage = shrinkage + self.covnorm = covnorm + self.shrink1 = shrink1 + self.shrink2 = shrink2 + self.tukey1 = tukey1 + + self.max_class = -1 + + @torch.no_grad() + def forward(self, x): + """ + :param x: (batch_size, feature_size) + + Returns a tensor of size (batch_size, num_classes) with + negative distance of each element in the mini-batch + with respect to each class. + """ + if self.class_means_dict == {}: + self.init_missing_classes(range(self.max_class + 1), x.shape[1], x.device) + + assert self.class_means_dict != {}, "no class means available." + + if self.tukey: + x = self._tukey_transforms(x) + + maha_dist = [] + for class_id, prototype in self.class_means_dict.items(): + cov = self.class_cov_dict[class_id] + dist = self._mahalanobis(x, prototype, cov) + maha_dist.append(dist) + + # n_classes, batch_size + maha_dis = torch.stack(maha_dist).T + + # (batch_size, num_classes) + return -maha_dis + + def _mahalanobis(self, vectors, class_means, cov): + x_minus_mu = F.normalize(vectors, p=2, dim=-1) - F.normalize( + class_means, p=2, dim=-1 + ) + inv_covmat = torch.linalg.pinv(cov).float().to(vectors.device) + left_term = torch.matmul(x_minus_mu, inv_covmat) + mahal = torch.matmul(left_term, x_minus_mu.T) + return torch.diagonal(mahal, 0) + + def _tukey_transforms(self, x): + x = torch.tensor(x) + if self.tukey1 == 0: + return torch.log(x) + else: + return torch.pow(x, self.tukey1) + + def _tukey_invert_transforms(self, x): + x = torch.tensor(x) + if self.tukey1 == 0: + return torch.exp(x) + else: + return torch.pow(x, 1 / self.tukey1) + + def _shrink_cov(self, cov): + diag_mean = torch.mean(torch.diagonal(cov)) + off_diag = cov.clone() + off_diag.fill_diagonal_(0.0) + mask = off_diag != 0.0 + off_diag_mean = (off_diag * mask).sum() / mask.sum() + iden = torch.eye(cov.shape[0]).to(cov.device) + cov_ = ( + cov + + (self.shrink1 * diag_mean * iden) + + (self.shrink2 * off_diag_mean * (1 - iden)) + ) + return cov_ + + def _normalize_cov(self, cov_mat): + norm_cov_mat = {} + for key, cov in cov_mat.items(): + sd = torch.sqrt(torch.diagonal(cov)) # standard deviations of the variables + cov = cov / (torch.matmul(sd.unsqueeze(1), sd.unsqueeze(0))) + norm_cov_mat[key] = cov + + return norm_cov_mat + + def update_class_means_dict( + self, class_means_dict: Dict[int, Tensor], momentum: float = 0.5 + ): + assert momentum <= 1 and momentum >= 0 + assert isinstance(class_means_dict, dict), ( + "class_means_dict must be a dictionary mapping class_id " "to mean vector" + ) + for k, v in class_means_dict.items(): + if k not in self.class_means_dict or (self.class_means_dict[k] == 0).all(): + self.class_means_dict[k] = class_means_dict[k].clone() + else: + device = self.class_means_dict[k].device + self.class_means_dict[k] = ( + momentum * class_means_dict[k].to(device) + + (1 - momentum) * self.class_means_dict[k] + ) + + def 
update_class_cov_dict( + self, class_cov_dict: Dict[int, Tensor], momentum: float = 0.5 + ): + assert momentum <= 1 and momentum >= 0 + assert isinstance(class_cov_dict, dict), ( + "class_cov_dict must be a dictionary mapping class_id " "to mean vector" + ) + for k, v in class_cov_dict.items(): + if k not in self.class_cov_dict or (self.class_cov_dict[k] == 0).all(): + self.class_cov_dict[k] = class_cov_dict[k].clone() + else: + device = self.class_cov_dict[k].device + self.class_cov_dict[k] = ( + momentum * class_cov_dict[k].to(device) + + (1 - momentum) * self.class_cov_dict[k] + ) + + def replace_class_means_dict( + self, + class_means_dict: Dict[int, Tensor], + ): + self.class_means_dict = class_means_dict + + def replace_class_cov_dict( + self, + class_cov_dict: Dict[int, Tensor], + ): + self.class_cov_dict = class_cov_dict + + def init_missing_classes(self, classes, class_size, device): + for k in classes: + if k not in self.class_means_dict: + self.class_means_dict[k] = torch.zeros(class_size).to(device) + self.class_cov_dict[k] = torch.eye(class_size).to(device) + + def eval_adaptation(self, experience): + classes = experience.classes_in_this_experience + for k in classes: + self.max_class = max(k, self.max_class) + + if len(self.class_means_dict) > 0: + self.init_missing_classes( + classes, + list(self.class_means_dict.values())[0].shape[0], + list(self.class_means_dict.values())[0].device, + ) + + def apply_transforms(self, features): + if self.tukey: + features = self._tukey_transforms(features) + return features + + def apply_invert_transforms(self, features): + if self.tukey: + features = self._tukey_invert_transforms(features) + return features + + def apply_cov_transforms(self, class_cov): + if self.shrinkage: + for key, cov in class_cov.items(): + class_cov[key] = self._shrink_cov(cov) + class_cov[key] = self._shrink_cov(class_cov[key]) + if self.covnorm: + class_cov = self._normalize_cov(class_cov) + return class_cov + + +def compute_covariance(features, labels) -> Dict: + class_cov = {} + for class_id in list(torch.unique(labels).cpu().int().numpy()): + mask = labels == class_id + class_features = features[mask] + cov = torch.cov(class_features.T) + class_cov[class_id] = cov + return class_cov + + +def compute_means(features, labels) -> Dict: + class_means = {} + for class_id in list(torch.unique(labels).cpu().int().numpy()): + mask = labels == class_id + class_features = features[mask] + prototype = torch.mean(class_features, dim=0) + class_means[class_id] = prototype + return class_means diff --git a/avalanche/training/plugins/update_fecam.py b/avalanche/training/plugins/update_fecam.py new file mode 100644 index 000000000..86d53d337 --- /dev/null +++ b/avalanche/training/plugins/update_fecam.py @@ -0,0 +1,179 @@ +#!/usr/bin/env python3 +import copy +from typing import Dict + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from avalanche.benchmarks.utils import concat_datasets +from avalanche.training.plugins import SupervisedPlugin +from avalanche.training.templates import SupervisedTemplate +from avalanche.training.storage_policy import ClassBalancedBuffer + +from avalanche.models.fecam import compute_means, compute_covariance + +class CurrentDataFeCAMUpdate(SupervisedPlugin): + """ + Updates FeCAM cov and prototypes + using the current task data + (at the end of each task) + """ + + def __init__(self): + super().__init__() + + def after_training_exp(self, strategy, **kwargs): + assert hasattr(strategy.model, 
"eval_classifier") + assert isinstance(strategy.model.eval_classifier, FeCAMClassifier) + + num_workers = kwargs["num_workers"] if "num_workers" in kwargs else 0 + loader = torch.utils.data.DataLoader( + strategy.adapted_dataset.eval(), + batch_size=strategy.train_mb_size, + shuffle=False, + num_workers=num_workers, + ) + + features = [] + labels = [] + + was_training = strategy.model.training + strategy.model.eval() + + for x, y, t in loader: + x = x.to(strategy.device) + y = y.to(strategy.device) + + with torch.no_grad(): + out = strategy.model.feature_extractor(x) + + features.append(out) + labels.append(y) + + if was_training: + strategy.model.train() + + features = torch.cat(features) + labels = torch.cat(labels) + + # Transform + features = strategy.model.eval_classifier.apply_transforms(features) + class_means = compute_means(features, labels) + class_cov = compute_covariance(features, labels) + class_cov = strategy.model.eval_classifier.apply_cov_transforms(class_cov) + + strategy.model.eval_classifier.update_class_means_dict(class_means) + strategy.model.eval_classifier.update_class_cov_dict(class_cov) + +class MemoryFeCAMUpdate(SupervisedPlugin): + """ + Updates FeCAM cov and prototypes + using the current task data + (at the end of each task) + """ + + def __init__(self, mem_size=2000, storage_policy=None): + super().__init__() + if storage_policy is None: + self.storage_policy = ClassBalancedBuffer(max_size=mem_size) + else: + self.storage_policy = storage_policy + + def after_training_exp(self, strategy, **kwargs): + self.storage_policy.update(strategy) + + num_workers = kwargs["num_workers"] if "num_workers" in kwargs else 0 + loader = torch.utils.data.DataLoader( + self.storage_policy.buffer.eval(), + batch_size=strategy.train_mb_size, + shuffle=False, + num_workers=num_workers, + ) + + features = [] + labels = [] + + was_training = strategy.model.training + strategy.model.eval() + + for x, y, t in loader: + x = x.to(strategy.device) + y = y.to(strategy.device) + + with torch.no_grad(): + out = strategy.model.feature_extractor(x) + + features.append(out) + labels.append(y) + + if was_training: + strategy.model.train() + + features = torch.cat(features) + labels = torch.cat(labels) + + # Transform + features = strategy.model.eval_classifier.apply_transforms(features) + class_means = compute_means(features, labels) + class_cov = compute_covariance(features, labels) + class_cov = strategy.model.eval_classifier.apply_cov_transforms(class_cov) + + strategy.model.eval_classifier.replace_class_means_dict(class_means) + strategy.model.eval_classifier.replace_class_cov_dict(class_cov) + +class FeCAMOracle(SupervisedPlugin): + """ + Updates FeCAM cov and prototypes + using the current task data + (at the end of each task) + """ + + def __init__(self): + super().__init__() + self.all_datasets = [] + + def after_training_exp(self, strategy, **kwargs): + self.all_datasets.append(strategy.experience.dataset) + full_dataset = concat_datasets(self.all_datasets) + num_workers = kwargs["num_workers"] if "num_workers" in kwargs else 0 + loader = torch.utils.data.DataLoader( + full_dataset.eval(), + batch_size=strategy.train_mb_size, + shuffle=False, + num_workers=num_workers, + ) + + features = [] + labels = [] + + was_training = strategy.model.training + strategy.model.eval() + + for x, y, t in loader: + x = x.to(strategy.device) + y = y.to(strategy.device) + + with torch.no_grad(): + out = strategy.model.feature_extractor(x) + + features.append(out) + labels.append(y) + + if was_training: + 
strategy.model.train() + + features = torch.cat(features) + labels = torch.cat(labels) + + # Transform + features = strategy.model.eval_classifier.apply_transforms(features) + class_means = compute_means(features, labels) + class_cov = compute_covariance(features, labels) + class_cov = strategy.model.eval_classifier.apply_cov_transforms(class_cov) + + strategy.model.eval_classifier.replace_class_means_dict(class_means) + strategy.model.eval_classifier.replace_class_cov_dict(class_cov) + +__all__ = ["CurrentDataFeCAMUpdate", "MemoryFeCAMUpdate", "FeCAMOracle"] diff --git a/avalanche/training/plugins/update_ncm.py b/avalanche/training/plugins/update_ncm.py new file mode 100644 index 000000000..3e439f2c5 --- /dev/null +++ b/avalanche/training/plugins/update_ncm.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python3 +import copy +from typing import Dict +import collections + +import numpy as np +import torch +import torch.nn.functional as F +import tqdm +from torch import Tensor, nn + +from avalanche.benchmarks.utils import concat_datasets +from avalanche.training.plugins import SupervisedPlugin +from avalanche.training.storage_policy import ClassBalancedBuffer +from avalanche.training.templates import SupervisedTemplate + + +@torch.no_grad() +def compute_class_means(model, dataset, batch_size, normalize, device, **kwargs): + class_means_dict = collections.defaultdict(list()) + class_counts = collections.defaultdict(lambda: 0) + num_workers = kwargs["num_workers"] if "num_workers" in kwargs else 0 + loader = torch.utils.data.DataLoader( + dataset.eval(), batch_size=batch_size, shuffle=False, num_workers=num_workers + ) + + model.eval() + + for x, y, t in loader: + x = x.to(device) + for class_idx in torch.unique(y): + mask = y == class_idx + out = model.feature_extractor(x[mask]) + class_means_dict[int(class_idx)].append(out) + class_counts[int(class_idx)] += len(x[mask]) + + for k, v in class_means_dict.items(): + v = torch.cat(v) + if normalize: + class_means_dict[k] = ( + torch.sum(v / torch.norm(v, dim=1, keepdim=True), dim=0) + / class_counts[k] + ) + else: + class_means_dict[k] = torch.sum(v, dim=0) / class_counts[k] + + if normalize: + class_means_dict[k] = class_means_dict[k] / class_means_dict[k].norm() + + model.train() + + return class_means_dict + + +class CurrentDataNCMUpdate(SupervisedPlugin): + def __init__(self): + super().__init__() + + # Maybe change with before_eval + @torch.no_grad() + def after_training_exp(self, strategy, **kwargs): + class_means_dict = compute_class_means( + strategy.model, + strategy.adapted_dataset, + strategy.train_mb_size, + normalize=strategy.model.eval_classifier.normalize, + device=strategy.device, + ) + strategy.model.eval_classifier.update_class_means_dict(class_means_dict) + + +class MemoryNCMUpdate(SupervisedPlugin): + """ + Updates NCM prototypes + using the current task data + (at the end of each task) + """ + + def __init__(self, mem_size=2000, storage_policy=None): + super().__init__() + if storage_policy is None: + self.storage_policy = ClassBalancedBuffer(max_size=mem_size) + else: + self.storage_policy = storage_policy + + def after_training_exp(self, strategy, **kwargs): + self.storage_policy.update(strategy) + class_means_dict = compute_class_means( + strategy.model, + self.storage_policy.buffer.eval(), + batch_size=strategy.train_mb_size, + normalize=strategy.model.eval_classifier.normalize, + device=strategy.device, + ) + strategy.model.eval_classifier.replace_class_means_dict(class_means_dict) + + +class NCMOracle(SupervisedPlugin): + def 
__init__(self): + super().__init__() + self.all_datasets = [] + + @torch.no_grad() + def after_training_exp(self, strategy, **kwargs): + self.all_datasets.append(strategy.experience.dataset) + accumulated_dataset = concat_datasets(self.all_datasets) + + class_means_dict = compute_class_means( + strategy.model, + accumulated_dataset, + strategy.train_mb_size, + normalize=strategy.model.eval_classifier.normalize, + device=strategy.device, + ) + + strategy.model.eval_classifier.replace_class_means_dict(class_means_dict) + + +__all__ = ["CurrentDataNCMUpdate", "MemoryNCMUpdate", "NCMOracle"] From 145ba5d2691bd124ed67d4540743a76ca899c6b3 Mon Sep 17 00:00:00 2001 From: AlbinSou Date: Mon, 6 Nov 2023 14:04:59 +0100 Subject: [PATCH 2/4] added fecam tests, some docs to update functions, load_state_dict in fecam --- avalanche/models/__init__.py | 1 + avalanche/models/cosine_layer.py | 29 +++- avalanche/models/fecam.py | 61 ++++++- avalanche/training/plugins/update_fecam.py | 183 +++++++++------------ avalanche/training/plugins/update_ncm.py | 30 +++- tests/models/test_models.py | 107 +++++++++--- 6 files changed, 261 insertions(+), 150 deletions(-) diff --git a/avalanche/models/__init__.py b/avalanche/models/__init__.py index e4f6a8efc..036ced5f0 100644 --- a/avalanche/models/__init__.py +++ b/avalanche/models/__init__.py @@ -27,3 +27,4 @@ from .prompt import Prompt from .vit import create_model from .scr_model import * +from .fecam import FeCAMClassifier diff --git a/avalanche/models/cosine_layer.py b/avalanche/models/cosine_layer.py index c8b0939e9..c2a6ede2a 100644 --- a/avalanche/models/cosine_layer.py +++ b/avalanche/models/cosine_layer.py @@ -13,8 +13,24 @@ class CosineLinear(nn.Module): + """ + Cosine layer defined in + "Learning a Unified Classifier Incrementally via Rebalancing" + by Saihui Hou et al. + + Implementation modified from https://github.com/G-U-N/PyCIL + + This layer is aimed at countering the task-recency bias by removing the bias + in the classifier and normalizing the weight and the input feature before + computing the weight-feature product + """ def __init__(self, in_features, out_features, sigma=True): - super(CosineLinear, self).__init__() + """ + :param in_features: number of input features + :param out_features: number of classes + :param sigma: learnable output scaling factor + """ + super().__init__() self.in_features = in_features self.out_features = out_features self.weight = nn.Parameter(torch.Tensor(self.out_features, in_features)) @@ -42,8 +58,9 @@ def forward(self, input): class SplitCosineLinear(nn.Module): """ - This class keeps two Cosine Linear layers, without sigma, and handles the sigma parameter - that is common for the two of them. One CosineLinear is for the old classes and the other + This class keeps two Cosine Linear layers, without sigma scaling, + and handles the sigma parameter that is common for the two of them. 
+ One CosineLinear is for the old classes and the other one is for the new classes """ @@ -85,8 +102,10 @@ def adaptation(self, experience): # Do not adapt return self.num_current_classes = max_class - fc = self.generate_fc(self.feature_dim, max_class + 1) + fc = self._generate_fc(self.feature_dim, max_class + 1) if experience.current_experience == 1: + # First exp self.fc is CosineLinear + # while it is SplitCosineLinear for subsequent exps fc.fc1.weight.data = self.fc.weight.data fc.sigma.data = self.fc.sigma.data else: @@ -100,7 +119,7 @@ def adaptation(self, experience): def forward(self, x): return self.fc(x) - def generate_fc(self, in_dim, out_dim): + def _generate_fc(self, in_dim, out_dim): fc = SplitCosineLinear( in_dim, self.fc.out_features, out_dim - self.fc.out_features ) diff --git a/avalanche/models/fecam.py b/avalanche/models/fecam.py index 46ce9d0dc..9179a7485 100644 --- a/avalanche/models/fecam.py +++ b/avalanche/models/fecam.py @@ -7,12 +7,7 @@ import tqdm from torch import Tensor, nn -from avalanche.benchmarks.utils import concat_datasets -from avalanche.evaluation.metric_results import MetricValue from avalanche.models import DynamicModule -from avalanche.training.plugins import SupervisedPlugin -from avalanche.training.storage_policy import ClassBalancedBuffer -from avalanche.training.templates import SupervisedTemplate class FeCAMClassifier(DynamicModule): @@ -52,6 +47,9 @@ def __init__( self.class_means_dict = {} self.class_cov_dict = {} + self.register_buffer("class_means", None) + self.register_buffer("class_covs", None) + self.tukey = tukey self.shrinkage = shrinkage self.covnorm = covnorm @@ -127,6 +125,36 @@ def _shrink_cov(self, cov): ) return cov_ + def _vectorize_means_dict(self): + if self.class_means_dict == {}: + return + + max_class = max(self.class_means_dict.keys()) + self.max_class = max(max_class, self.max_class) + first_mean = list(self.class_means_dict.values())[0] + feature_size = first_mean.size(0) + device = first_mean.device + self.class_means = torch.zeros(self.max_class + 1, feature_size).to(device) + + for k, v in self.class_means_dict.items(): + self.class_means[k] = self.class_means_dict[k].clone() + + def _vectorize_cov_dict(self): + if self.class_cov_dict == {}: + return + + max_class = max(self.class_cov_dict.keys()) + self.max_class = max(max_class, self.max_class) + first_mean = list(self.class_cov_dict.values())[0] + feature_size = first_mean.size(0) + device = first_mean.device + self.class_covs = torch.zeros( + self.max_class + 1, feature_size, feature_size + ).to(device) + + for k, v in self.class_cov_dict.items(): + self.class_covs[k] = self.class_cov_dict[k].clone() + def _normalize_cov(self, cov_mat): norm_cov_mat = {} for key, cov in cov_mat.items(): @@ -152,6 +180,7 @@ def update_class_means_dict( momentum * class_means_dict[k].to(device) + (1 - momentum) * self.class_means_dict[k] ) + self._vectorize_means_dict() def update_class_cov_dict( self, class_cov_dict: Dict[int, Tensor], momentum: float = 0.5 @@ -169,18 +198,21 @@ def update_class_cov_dict( momentum * class_cov_dict[k].to(device) + (1 - momentum) * self.class_cov_dict[k] ) + self._vectorize_cov_dict() def replace_class_means_dict( self, class_means_dict: Dict[int, Tensor], ): self.class_means_dict = class_means_dict + self._vectorize_means_dict() def replace_class_cov_dict( self, class_cov_dict: Dict[int, Tensor], ): self.class_cov_dict = class_cov_dict + self._vectorize_cov_dict() def init_missing_classes(self, classes, class_size, device): for k in classes: @@ 
-219,6 +251,25 @@ def apply_cov_transforms(self, class_cov): class_cov = self._normalize_cov(class_cov) return class_cov + def load_state_dict(self, state_dict, strict: bool = True): + self.class_means = state_dict["class_means"] + self.class_covs = state_dict["class_covs"] + + super().load_state_dict(state_dict, strict) + + # fill dictionary + if self.class_means is not None: + for i in range(self.class_means.shape[0]): + if (self.class_means[i] != 0).any(): + self.class_means_dict[i] = self.class_means[i].clone() + + self.max_class = max(self.class_means_dict.keys()) + + if self.class_covs is not None: + for i in range(self.class_covs.shape[0]): + if (self.class_covs[i] != 0).any(): + self.class_cov_dict[i] = self.class_covs[i].clone() + def compute_covariance(features, labels) -> Dict: class_cov = {} diff --git a/avalanche/training/plugins/update_fecam.py b/avalanche/training/plugins/update_fecam.py index 86d53d337..8d2246b01 100644 --- a/avalanche/training/plugins/update_fecam.py +++ b/avalanche/training/plugins/update_fecam.py @@ -8,11 +8,51 @@ from torch import Tensor, nn from avalanche.benchmarks.utils import concat_datasets +from avalanche.models.fecam import compute_covariance, compute_means from avalanche.training.plugins import SupervisedPlugin -from avalanche.training.templates import SupervisedTemplate from avalanche.training.storage_policy import ClassBalancedBuffer +from avalanche.training.templates import SupervisedTemplate + + +def _gather_means_and_cov(model, dataset, batch_size, device, **kwargs): + num_workers = kwargs["num_workers"] if "num_workers" in kwargs else 0 + loader = torch.utils.data.DataLoader( + dataset.eval(), + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + ) + + features = [] + labels = [] + + was_training = model.training + model.eval() + + for x, y, t in loader: + x = x.to(device) + y = y.to(device) + + with torch.no_grad(): + out = model.feature_extractor(x) + + features.append(out) + labels.append(y) + + if was_training: + model.train() + + features = torch.cat(features) + labels = torch.cat(labels) + + # Transform + features = model.eval_classifier.apply_transforms(features) + class_means = compute_means(features, labels) + class_cov = compute_covariance(features, labels) + class_cov = model.eval_classifier.apply_cov_transforms(class_cov) + + return class_means, class_cov -from avalanche.models.fecam import compute_means, compute_covariance class CurrentDataFeCAMUpdate(SupervisedPlugin): """ @@ -28,50 +68,22 @@ def after_training_exp(self, strategy, **kwargs): assert hasattr(strategy.model, "eval_classifier") assert isinstance(strategy.model.eval_classifier, FeCAMClassifier) - num_workers = kwargs["num_workers"] if "num_workers" in kwargs else 0 - loader = torch.utils.data.DataLoader( - strategy.adapted_dataset.eval(), - batch_size=strategy.train_mb_size, - shuffle=False, - num_workers=num_workers, + class_means, class_cov = _gather_means_and_cov( + strategy.model, + strategy.experience.dataset, + strategy.train_mb_size, + strategy.device, + **kwargs ) - features = [] - labels = [] - - was_training = strategy.model.training - strategy.model.eval() - - for x, y, t in loader: - x = x.to(strategy.device) - y = y.to(strategy.device) - - with torch.no_grad(): - out = strategy.model.feature_extractor(x) - - features.append(out) - labels.append(y) - - if was_training: - strategy.model.train() - - features = torch.cat(features) - labels = torch.cat(labels) - - # Transform - features = 
strategy.model.eval_classifier.apply_transforms(features) - class_means = compute_means(features, labels) - class_cov = compute_covariance(features, labels) - class_cov = strategy.model.eval_classifier.apply_cov_transforms(class_cov) - strategy.model.eval_classifier.update_class_means_dict(class_means) strategy.model.eval_classifier.update_class_cov_dict(class_cov) + class MemoryFeCAMUpdate(SupervisedPlugin): """ Updates FeCAM cov and prototypes - using the current task data - (at the end of each task) + using the data contained inside a memory buffer """ def __init__(self, mem_size=2000, storage_policy=None): @@ -84,50 +96,29 @@ def __init__(self, mem_size=2000, storage_policy=None): def after_training_exp(self, strategy, **kwargs): self.storage_policy.update(strategy) - num_workers = kwargs["num_workers"] if "num_workers" in kwargs else 0 - loader = torch.utils.data.DataLoader( + class_means, class_cov = _gather_means_and_cov( + strategy.model, self.storage_policy.buffer.eval(), - batch_size=strategy.train_mb_size, - shuffle=False, - num_workers=num_workers, + strategy.train_mb_size, + strategy.device, + **kwargs ) - features = [] - labels = [] - - was_training = strategy.model.training - strategy.model.eval() - - for x, y, t in loader: - x = x.to(strategy.device) - y = y.to(strategy.device) - - with torch.no_grad(): - out = strategy.model.feature_extractor(x) - - features.append(out) - labels.append(y) - - if was_training: - strategy.model.train() - - features = torch.cat(features) - labels = torch.cat(labels) - - # Transform - features = strategy.model.eval_classifier.apply_transforms(features) - class_means = compute_means(features, labels) - class_cov = compute_covariance(features, labels) - class_cov = strategy.model.eval_classifier.apply_cov_transforms(class_cov) + strategy.model.eval_classifier.update_class_means_dict(class_means) + strategy.model.eval_classifier.update_class_cov_dict(class_cov) - strategy.model.eval_classifier.replace_class_means_dict(class_means) - strategy.model.eval_classifier.replace_class_cov_dict(class_cov) class FeCAMOracle(SupervisedPlugin): """ Updates FeCAM cov and prototypes - using the current task data - (at the end of each task) + using all the data seen so far + WARNING: This is an oracle, + and thus breaks assumptions usually made + in continual learning algorithms i + (storage of full dataset) + This is meant to be used as an upper bound + for FeCAM based methods + (i.e when trying to estimate prototype and covariance drift) """ def __init__(self): @@ -137,43 +128,17 @@ def __init__(self): def after_training_exp(self, strategy, **kwargs): self.all_datasets.append(strategy.experience.dataset) full_dataset = concat_datasets(self.all_datasets) - num_workers = kwargs["num_workers"] if "num_workers" in kwargs else 0 - loader = torch.utils.data.DataLoader( - full_dataset.eval(), - batch_size=strategy.train_mb_size, - shuffle=False, - num_workers=num_workers, - ) - - features = [] - labels = [] - - was_training = strategy.model.training - strategy.model.eval() - - for x, y, t in loader: - x = x.to(strategy.device) - y = y.to(strategy.device) - - with torch.no_grad(): - out = strategy.model.feature_extractor(x) - features.append(out) - labels.append(y) - - if was_training: - strategy.model.train() - - features = torch.cat(features) - labels = torch.cat(labels) + class_means, class_cov = _gather_means_and_cov( + strategy.model, + full_dataset, + strategy.train_mb_size, + strategy.device, + **kwargs + ) - # Transform - features = 
strategy.model.eval_classifier.apply_transforms(features) - class_means = compute_means(features, labels) - class_cov = compute_covariance(features, labels) - class_cov = strategy.model.eval_classifier.apply_cov_transforms(class_cov) + strategy.model.eval_classifier.update_class_means_dict(class_means) + strategy.model.eval_classifier.update_class_cov_dict(class_cov) - strategy.model.eval_classifier.replace_class_means_dict(class_means) - strategy.model.eval_classifier.replace_class_cov_dict(class_cov) __all__ = ["CurrentDataFeCAMUpdate", "MemoryFeCAMUpdate", "FeCAMOracle"] diff --git a/avalanche/training/plugins/update_ncm.py b/avalanche/training/plugins/update_ncm.py index 3e439f2c5..d10befc17 100644 --- a/avalanche/training/plugins/update_ncm.py +++ b/avalanche/training/plugins/update_ncm.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 +import collections import copy from typing import Dict -import collections import numpy as np import torch @@ -13,6 +13,7 @@ from avalanche.training.plugins import SupervisedPlugin from avalanche.training.storage_policy import ClassBalancedBuffer from avalanche.training.templates import SupervisedTemplate +from avalanche.models import NCMClassifier @torch.no_grad() @@ -51,14 +52,23 @@ def compute_class_means(model, dataset, batch_size, normalize, device, **kwargs) return class_means_dict +def _check_has_ncm(model): + assert hasattr(model, "eval_classifier") + assert isinstance(model.eval_classifier, NCMClassifier) + class CurrentDataNCMUpdate(SupervisedPlugin): + """ + Updates the NCM prototypes + using the current task data + """ def __init__(self): super().__init__() # Maybe change with before_eval @torch.no_grad() def after_training_exp(self, strategy, **kwargs): + _check_has_ncm(strategy.model) class_means_dict = compute_class_means( strategy.model, strategy.adapted_dataset, @@ -72,10 +82,9 @@ def after_training_exp(self, strategy, **kwargs): class MemoryNCMUpdate(SupervisedPlugin): """ Updates NCM prototypes - using the current task data - (at the end of each task) + using the data contained inside a memory buffer + (as is is done in ICaRL) """ - def __init__(self, mem_size=2000, storage_policy=None): super().__init__() if storage_policy is None: @@ -84,6 +93,7 @@ def __init__(self, mem_size=2000, storage_policy=None): self.storage_policy = storage_policy def after_training_exp(self, strategy, **kwargs): + _check_has_ncm(strategy.model) self.storage_policy.update(strategy) class_means_dict = compute_class_means( strategy.model, @@ -96,12 +106,24 @@ def after_training_exp(self, strategy, **kwargs): class NCMOracle(SupervisedPlugin): + """ + Updates NCM prototypes + using all the data seen so far + WARNING: This is an oracle, + and thus breaks assumptions usually made + in continual learning algorithms i + (storage of full dataset) + This is meant to be used as an upper bound + for NCM based methods + (i.e when trying to estimate prototype drift) + """ def __init__(self): super().__init__() self.all_datasets = [] @torch.no_grad() def after_training_exp(self, strategy, **kwargs): + _check_has_ncm(strategy.model) self.all_datasets.append(strategy.experience.dataset) accumulated_dataset = concat_datasets(self.all_datasets) diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 4972a7ff3..107c05132 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -1,43 +1,31 @@ -import sys -import os import copy +import os +import sys import tempfile - import unittest import pytorchcv.models.pyramidnet_cifar import torch +from 
tests.benchmarks.utils.test_avalanche_classification_dataset import \ + get_mbatch +from tests.unit_tests_utils import (common_setups, get_fast_benchmark, + load_benchmark) from torch.nn import CrossEntropyLoss from torch.optim import SGD from torch.utils.data import DataLoader from avalanche.logging import TextLogger -from avalanche.models import ( - MTSimpleMLP, - SimpleMLP, - IncrementalClassifier, - MultiHeadClassifier, - SimpleCNN, - NCMClassifier, - TrainEvalModel, - PNN, -) -from avalanche.models.dynamic_optimizers import ( - add_new_params_to_optimizer, - update_optimizer, -) +from avalanche.models import (PNN, FeCAMClassifier, IncrementalClassifier, + MTSimpleMLP, MultiHeadClassifier, NCMClassifier, + SimpleCNN, SimpleMLP, TrainEvalModel) +from avalanche.models.dynamic_optimizers import (add_new_params_to_optimizer, + update_optimizer) +from avalanche.models.pytorchcv_wrapper import (densenet, get_model, + pyramidnet, resnet, vgg) from avalanche.models.utils import avalanche_model_adaptation +from avalanche.training.checkpoint import (maybe_load_checkpoint, + save_checkpoint) from avalanche.training.supervised import Naive -from avalanche.models.pytorchcv_wrapper import ( - vgg, - resnet, - densenet, - pyramidnet, - get_model, -) -from tests.unit_tests_utils import common_setups, load_benchmark, get_fast_benchmark -from tests.benchmarks.utils.test_avalanche_classification_dataset import get_mbatch -from avalanche.training.checkpoint import save_checkpoint, maybe_load_checkpoint class PytorchcvWrapperTests(unittest.TestCase): @@ -664,6 +652,71 @@ def test_ncm_save_load(self): assert len(classifier.class_means_dict) == 2 +class FeCAMClassifierTest(unittest.TestCase): + def test_fecam_classification(self): + class_means = torch.tensor( + [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]], + dtype=torch.float, + ) + class_means_dict = {i: el for i, el in enumerate(class_means)} + class_cov_dict = { + i: torch.eye(el.size(0)) for i, el in class_means_dict.items() + } + + mb_x = torch.tensor( + [[4, 3, 2, 1], [3, 2, 4, 1]], + dtype=torch.float, + ) + + mb_y = torch.tensor([0, 2], dtype=torch.float) + + classifier = FeCAMClassifier() + + classifier.update_class_means_dict(class_means_dict) + classifier.update_class_cov_dict(class_cov_dict) + + pred = classifier(mb_x) + assert torch.all(torch.max(pred, 1)[1] == mb_y) + + def test_fecam_forward_without_class_means(self): + classifier = FeCAMClassifier() + classifier.init_missing_classes(list(range(10)), 7, "cpu") + logits = classifier(torch.randn(2, 7)) + assert logits.shape == (2, 10) + + def test_ncm_save_load(self): + classifier = FeCAMClassifier() + + classifier.update_class_means_dict( + { + 1: torch.randn( + 5, + ), + 2: torch.randn( + 5, + ), + } + ) + + classifier.update_class_cov_dict( + { + 1: torch.rand(5, 5), + 2: torch.rand(5, 5), + } + ) + + with tempfile.TemporaryFile() as tmpfile: + torch.save(classifier.state_dict(), tmpfile) + del classifier + classifier = FeCAMClassifier() + tmpfile.seek(0) + check = torch.load(tmpfile) + + classifier.load_state_dict(check) + + assert len(classifier.class_means_dict) == 2 + + class PNNTest(unittest.TestCase): def test_pnn_on_multiple_tasks(self): model = PNN( From 12e89c93692e99c1a1d8521f799cd7f2dd67abfa Mon Sep 17 00:00:00 2001 From: AlbinSou Date: Thu, 9 Nov 2023 14:10:54 +0100 Subject: [PATCH 3/4] added tests for cosine linear, also slightly modified it to handle random class orders --- avalanche/models/__init__.py | 1 + avalanche/models/cosine_layer.py | 90 +++++++++++++++++++----- 
avalanche/training/plugins/update_ncm.py | 12 ++-- tests/models/test_models.py | 83 ++++++++++++++++++---- 4 files changed, 151 insertions(+), 35 deletions(-) diff --git a/avalanche/models/__init__.py b/avalanche/models/__init__.py index 036ced5f0..a81550683 100644 --- a/avalanche/models/__init__.py +++ b/avalanche/models/__init__.py @@ -28,3 +28,4 @@ from .vit import create_model from .scr_model import * from .fecam import FeCAMClassifier +from .cosine_layer import CosineIncrementalClassifier, CosineLinear diff --git a/avalanche/models/cosine_layer.py b/avalanche/models/cosine_layer.py index c2a6ede2a..3ba08982a 100644 --- a/avalanche/models/cosine_layer.py +++ b/avalanche/models/cosine_layer.py @@ -1,8 +1,10 @@ #!/usr/bin/env python3 import math + import torch import torch.nn as nn import torch.nn.functional as F +import numpy as np from avalanche.models import DynamicModule @@ -14,16 +16,17 @@ class CosineLinear(nn.Module): """ - Cosine layer defined in - "Learning a Unified Classifier Incrementally via Rebalancing" + Cosine layer defined in + "Learning a Unified Classifier Incrementally via Rebalancing" by Saihui Hou et al. Implementation modified from https://github.com/G-U-N/PyCIL - This layer is aimed at countering the task-recency bias by removing the bias - in the classifier and normalizing the weight and the input feature before + This layer is aimed at countering the task-recency bias by removing the bias + in the classifier and normalizing the weight and the input feature before computing the weight-feature product """ + def __init__(self, in_features, out_features, sigma=True): """ :param in_features: number of input features @@ -58,8 +61,8 @@ def forward(self, input): class SplitCosineLinear(nn.Module): """ - This class keeps two Cosine Linear layers, without sigma scaling, - and handles the sigma parameter that is common for the two of them. + This class keeps two Cosine Linear layers, without sigma scaling, + and handles the sigma parameter that is common for the two of them. One CosineLinear is for the old classes and the other one is for the new classes """ @@ -89,38 +92,89 @@ def forward(self, x): class CosineIncrementalClassifier(DynamicModule): - # WARNING Maybe does not work with initial evaluation - def __init__(self, in_features, num_classes): + """ + Equivalent to IncrementalClassifier but using the cosine layer + described in "Learning a Unified Classifier Incrementally via Rebalancing" + by Saihui Hou et al. + """ + + def __init__(self, in_features, num_classes=0): + """ + :param in_features: Number of input features + :param num_classes: Number of initial classes (default=0) + If set to more than 0, the initial logits + will be mapped to the corresponding sequence of + classes starting from 0. 
+ """ super().__init__() - self.fc = CosineLinear(in_features, num_classes) - self.num_current_classes = num_classes + self.class_order = [] + self.classes = set() + + if num_classes == 0: + self.fc = None + else: + self.fc = CosineLinear(in_features, num_classes, sigma=True) + for i in range(num_classes): + self.class_order.append(i) + self.classes = set(range(5)) + self.feature_dim = in_features def adaptation(self, experience): - max_class = torch.max(experience.classes_in_this_experience)[0] - if max_class <= self.num_current_classes: + num_classes = len(experience.classes_in_this_experience) + + new_classes = set(experience.classes_in_this_experience) - set(self.classes) + + if len(new_classes) == 0: # Do not adapt return - self.num_current_classes = max_class - fc = self._generate_fc(self.feature_dim, max_class + 1) - if experience.current_experience == 1: - # First exp self.fc is CosineLinear + + self.classes = self.classes.union(new_classes) + + for c in list(new_classes): + self.class_order.append(c) + + max_index = len(self.class_order) + + if self.fc is None: + self.fc = CosineLinear(self.feature_dim, max_index, sigma=True) + return + + fc = self._generate_fc(self.feature_dim, max_index) + + if isinstance(self.fc, CosineLinear): + # First exp self.fc is CosineLinear # while it is SplitCosineLinear for subsequent exps fc.fc1.weight.data = self.fc.weight.data fc.sigma.data = self.fc.sigma.data - else: + elif isinstance(self.fc, SplitCosineLinear): prev_out_features1 = self.fc.fc1.out_features fc.fc1.weight.data[:prev_out_features1] = self.fc.fc1.weight.data fc.fc1.weight.data[prev_out_features1:] = self.fc.fc2.weight.data fc.sigma.data = self.fc.sigma.data + del self.fc self.fc = fc def forward(self, x): - return self.fc(x) + unmapped_logits = self.fc(x) + + # Mask by default unseen classes + mapped_logits = ( + torch.ones(len(unmapped_logits), np.max(self.class_order) + 1) * -1000 + ) + mapped_logits.to(x.device) + + # Now map to classes + mapped_logits[:, self.class_order] = unmapped_logits + + return mapped_logits def _generate_fc(self, in_dim, out_dim): fc = SplitCosineLinear( in_dim, self.fc.out_features, out_dim - self.fc.out_features ) return fc + + +__all__ = ["CosineLinear", "CosineIncrementalClassifier"] diff --git a/avalanche/training/plugins/update_ncm.py b/avalanche/training/plugins/update_ncm.py index d10befc17..e062fe34a 100644 --- a/avalanche/training/plugins/update_ncm.py +++ b/avalanche/training/plugins/update_ncm.py @@ -52,6 +52,7 @@ def compute_class_means(model, dataset, batch_size, normalize, device, **kwargs) return class_means_dict + def _check_has_ncm(model): assert hasattr(model, "eval_classifier") assert isinstance(model.eval_classifier, NCMClassifier) @@ -62,6 +63,7 @@ class CurrentDataNCMUpdate(SupervisedPlugin): Updates the NCM prototypes using the current task data """ + def __init__(self): super().__init__() @@ -85,6 +87,7 @@ class MemoryNCMUpdate(SupervisedPlugin): using the data contained inside a memory buffer (as is is done in ICaRL) """ + def __init__(self, mem_size=2000, storage_policy=None): super().__init__() if storage_policy is None: @@ -109,14 +112,15 @@ class NCMOracle(SupervisedPlugin): """ Updates NCM prototypes using all the data seen so far - WARNING: This is an oracle, - and thus breaks assumptions usually made + WARNING: This is an oracle, + and thus breaks assumptions usually made in continual learning algorithms i (storage of full dataset) - This is meant to be used as an upper bound - for NCM based methods + This is meant to be 
used as an upper bound + for NCM based methods (i.e when trying to estimate prototype drift) """ + def __init__(self): super().__init__() self.all_datasets = [] diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 107c05132..ac745ec57 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -3,28 +3,43 @@ import sys import tempfile import unittest +import numpy as np import pytorchcv.models.pyramidnet_cifar import torch -from tests.benchmarks.utils.test_avalanche_classification_dataset import \ - get_mbatch -from tests.unit_tests_utils import (common_setups, get_fast_benchmark, - load_benchmark) +from tests.benchmarks.utils.test_avalanche_classification_dataset import get_mbatch +from tests.unit_tests_utils import common_setups, get_fast_benchmark, load_benchmark from torch.nn import CrossEntropyLoss from torch.optim import SGD from torch.utils.data import DataLoader from avalanche.logging import TextLogger -from avalanche.models import (PNN, FeCAMClassifier, IncrementalClassifier, - MTSimpleMLP, MultiHeadClassifier, NCMClassifier, - SimpleCNN, SimpleMLP, TrainEvalModel) -from avalanche.models.dynamic_optimizers import (add_new_params_to_optimizer, - update_optimizer) -from avalanche.models.pytorchcv_wrapper import (densenet, get_model, - pyramidnet, resnet, vgg) +from avalanche.models import ( + PNN, + CosineIncrementalClassifier, + FeCAMClassifier, + IncrementalClassifier, + MTSimpleMLP, + MultiHeadClassifier, + NCMClassifier, + SimpleCNN, + SimpleMLP, + TrainEvalModel, +) +from avalanche.models.cosine_layer import CosineLinear, SplitCosineLinear +from avalanche.models.dynamic_optimizers import ( + add_new_params_to_optimizer, + update_optimizer, +) +from avalanche.models.pytorchcv_wrapper import ( + densenet, + get_model, + pyramidnet, + resnet, + vgg, +) from avalanche.models.utils import avalanche_model_adaptation -from avalanche.training.checkpoint import (maybe_load_checkpoint, - save_checkpoint) +from avalanche.training.checkpoint import maybe_load_checkpoint, save_checkpoint from avalanche.training.supervised import Naive @@ -717,6 +732,48 @@ def test_ncm_save_load(self): assert len(classifier.class_means_dict) == 2 +class CosineLayerTest(unittest.TestCase): + def test_single_cosine(self): + layer = CosineLinear(32, 10) + test_input = torch.rand(5, 32) + out = layer(test_input) + out.sum().backward() + + def test_split_cosine(self): + in_feat_1, in_feat_2 = 10, 10 + layer = SplitCosineLinear(32, in_feat_1, in_feat_2) + test_input = torch.rand(5, 32) + out = layer(test_input) + self.assertEqual(out.size(1), in_feat_1 + in_feat_2) + out.sum().backward() + + def test_cosine_incremental_adaptation(self): + benchmark = load_benchmark(use_task_labels=False) + num_classes_0 = np.max(benchmark.train_stream[0].classes_in_this_experience) + 1 + num_classes_1 = np.max(benchmark.train_stream[1].classes_in_this_experience) + 1 + + test_input = torch.rand(5, 32) + + # Without initial classes + layer = CosineIncrementalClassifier(32, num_classes=0) + avalanche_model_adaptation(layer, benchmark.train_stream[0]) + out = layer(test_input) + self.assertEqual(out.size(1), num_classes_0) + avalanche_model_adaptation(layer, benchmark.train_stream[1]) + out = layer(test_input) + self.assertEqual(out.size(1), max(num_classes_0, num_classes_1)) + + # With initial classes + initial_classes = 5 + layer = CosineIncrementalClassifier(32, num_classes=initial_classes) + avalanche_model_adaptation(layer, benchmark.train_stream[0]) + out = layer(test_input) + 
self.assertEqual(out.size(1), max(num_classes_0, initial_classes)) + + # Test backward + out.sum().backward() + + class PNNTest(unittest.TestCase): def test_pnn_on_multiple_tasks(self): model = PNN( From ccdd1ff5939d8fe799d0d8436df8dfc03339f5f3 Mon Sep 17 00:00:00 2001 From: AlbinSou Date: Thu, 16 Nov 2023 15:25:48 +0100 Subject: [PATCH 4/4] added update utils --- avalanche/training/plugins/__init__.py | 2 + avalanche/training/plugins/update_fecam.py | 13 ++- avalanche/training/plugins/update_ncm.py | 4 +- tests/training/test_update_utils.py | 96 ++++++++++++++++++++++ 4 files changed, 111 insertions(+), 4 deletions(-) create mode 100644 tests/training/test_update_utils.py diff --git a/avalanche/training/plugins/__init__.py b/avalanche/training/plugins/__init__.py index 876e8ac71..13ee1897f 100644 --- a/avalanche/training/plugins/__init__.py +++ b/avalanche/training/plugins/__init__.py @@ -23,3 +23,5 @@ from .mir import MIRPlugin from .from_scratch_training import FromScratchTrainingPlugin from .rar import RARPlugin +from .update_ncm import * +from .update_fecam import * diff --git a/avalanche/training/plugins/update_fecam.py b/avalanche/training/plugins/update_fecam.py index 8d2246b01..a292a5525 100644 --- a/avalanche/training/plugins/update_fecam.py +++ b/avalanche/training/plugins/update_fecam.py @@ -8,6 +8,7 @@ from torch import Tensor, nn from avalanche.benchmarks.utils import concat_datasets +from avalanche.models import FeCAMClassifier from avalanche.models.fecam import compute_covariance, compute_means from avalanche.training.plugins import SupervisedPlugin from avalanche.training.storage_policy import ClassBalancedBuffer @@ -54,6 +55,11 @@ def _gather_means_and_cov(model, dataset, batch_size, device, **kwargs): return class_means, class_cov +def _check_has_fecam(model): + assert hasattr(model, "eval_classifier") + assert isinstance(model.eval_classifier, FeCAMClassifier) + + class CurrentDataFeCAMUpdate(SupervisedPlugin): """ Updates FeCAM cov and prototypes @@ -65,8 +71,7 @@ def __init__(self): super().__init__() def after_training_exp(self, strategy, **kwargs): - assert hasattr(strategy.model, "eval_classifier") - assert isinstance(strategy.model.eval_classifier, FeCAMClassifier) + _check_has_fecam(strategy.model) class_means, class_cov = _gather_means_and_cov( strategy.model, @@ -94,6 +99,8 @@ def __init__(self, mem_size=2000, storage_policy=None): self.storage_policy = storage_policy def after_training_exp(self, strategy, **kwargs): + _check_has_fecam(strategy.model) + self.storage_policy.update(strategy) class_means, class_cov = _gather_means_and_cov( @@ -126,6 +133,8 @@ def __init__(self): self.all_datasets = [] def after_training_exp(self, strategy, **kwargs): + _check_has_fecam(strategy.model) + self.all_datasets.append(strategy.experience.dataset) full_dataset = concat_datasets(self.all_datasets) diff --git a/avalanche/training/plugins/update_ncm.py b/avalanche/training/plugins/update_ncm.py index e062fe34a..22c9d9d7d 100644 --- a/avalanche/training/plugins/update_ncm.py +++ b/avalanche/training/plugins/update_ncm.py @@ -18,7 +18,7 @@ @torch.no_grad() def compute_class_means(model, dataset, batch_size, normalize, device, **kwargs): - class_means_dict = collections.defaultdict(list()) + class_means_dict = collections.defaultdict(list) class_counts = collections.defaultdict(lambda: 0) num_workers = kwargs["num_workers"] if "num_workers" in kwargs else 0 loader = torch.utils.data.DataLoader( @@ -73,7 +73,7 @@ def after_training_exp(self, strategy, **kwargs): 
_check_has_ncm(strategy.model) class_means_dict = compute_class_means( strategy.model, - strategy.adapted_dataset, + strategy.experience.dataset, strategy.train_mb_size, normalize=strategy.model.eval_classifier.normalize, device=strategy.device, diff --git a/tests/training/test_update_utils.py b/tests/training/test_update_utils.py new file mode 100644 index 000000000..586b41db1 --- /dev/null +++ b/tests/training/test_update_utils.py @@ -0,0 +1,96 @@ +import unittest + +import torch +import torch.nn as nn + +from avalanche.models import FeCAMClassifier, NCMClassifier, SimpleMLP, TrainEvalModel +from avalanche.training.plugins import ( + CurrentDataFeCAMUpdate, + CurrentDataNCMUpdate, + FeCAMOracle, + MemoryFeCAMUpdate, + MemoryNCMUpdate, + NCMOracle, +) +from avalanche.training.supervised import Naive +from tests.unit_tests_utils import load_benchmark + + +class UpdateNCMTest(unittest.TestCase): + def create_strategy_and_benchmark(self): + model = SimpleMLP(input_size=6) + old_layer = model.classifier + model.classifier = nn.Identity() + model = TrainEvalModel( + model, train_classifier=old_layer, eval_classifier=NCMClassifier() + ) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + strategy = Naive(model, optimizer) + benchmark = load_benchmark() + return strategy, benchmark + + def test_current_update(self): + plugin = CurrentDataNCMUpdate() + self._test_plugin(plugin) + + def test_memory_update(self): + plugin = MemoryNCMUpdate(100) + self._test_plugin(plugin) + + def test_oracle_update(self): + plugin = NCMOracle() + self._test_plugin(plugin) + + def _test_plugin(self, plugin): + strategy, benchmark = self.create_strategy_and_benchmark() + strategy.plugins.append(plugin) + strategy.experience = benchmark.train_stream[0] + test_experience = benchmark.test_stream[0] + strategy._after_training_exp() + strategy.model.eval() + + loader = iter( + torch.utils.data.DataLoader(test_experience.dataset, batch_size=10) + ) + batch_x, batch_y, batch_t = next(loader) + result = strategy.model(batch_x) + + +class UpdateFeCAMTest(unittest.TestCase): + def create_strategy_and_benchmark(self): + model = SimpleMLP(input_size=6) + old_layer = model.classifier + model.classifier = nn.Identity() + model = TrainEvalModel( + model, train_classifier=old_layer, eval_classifier=FeCAMClassifier() + ) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + strategy = Naive(model, optimizer) + benchmark = load_benchmark() + return strategy, benchmark + + def test_current_update(self): + plugin = CurrentDataFeCAMUpdate() + self._test_plugin(plugin) + + def test_memory_update(self): + plugin = MemoryFeCAMUpdate(100) + self._test_plugin(plugin) + + def test_oracle_update(self): + plugin = FeCAMOracle() + self._test_plugin(plugin) + + def _test_plugin(self, plugin): + strategy, benchmark = self.create_strategy_and_benchmark() + strategy.plugins.append(plugin) + strategy.experience = benchmark.train_stream[0] + test_experience = benchmark.test_stream[0] + strategy._after_training_exp() + strategy.model.eval() + + loader = iter( + torch.utils.data.DataLoader(test_experience.dataset, batch_size=10) + ) + batch_x, batch_y, batch_t = next(loader) + result = strategy.model(batch_x)
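
Usage sketch, following the pattern of the new tests in tests/training/test_update_utils.py: the snippet below shows how the components added in this series are intended to be combined, with the backbone split into a feature extractor plus a linear head that is used only for the training loss, while evaluation goes through the exemplar-free FeCAM classifier that the plugin keeps up to date. This is a minimal sketch, not part of the patches above; the SplitMNIST benchmark, the SGD settings, and the batch sizes are illustrative assumptions rather than values taken from the patches.

    import torch.nn as nn
    from torch.nn import CrossEntropyLoss
    from torch.optim import SGD

    from avalanche.benchmarks.classic import SplitMNIST
    from avalanche.models import FeCAMClassifier, SimpleMLP, TrainEvalModel
    from avalanche.training.plugins import CurrentDataFeCAMUpdate
    from avalanche.training.supervised import Naive

    # Illustrative benchmark choice; any class-incremental benchmark works.
    benchmark = SplitMNIST(n_experiences=5)

    # Split the model: the original linear head is kept only to compute the
    # cross-entropy training loss, while predictions in eval mode come from
    # the FeCAM classifier (Mahalanobis distance to per-class prototypes).
    model = SimpleMLP(num_classes=10, input_size=28 * 28)
    train_head = model.classifier
    model.classifier = nn.Identity()
    model = TrainEvalModel(
        model, train_classifier=train_head, eval_classifier=FeCAMClassifier()
    )

    optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)  # illustrative

    # CurrentDataFeCAMUpdate recomputes the per-class means and the shrunk,
    # normalized covariance matrices from the current experience at the end
    # of each training experience.
    strategy = Naive(
        model,
        optimizer,
        CrossEntropyLoss(),
        train_mb_size=64,
        train_epochs=1,
        eval_mb_size=128,
        plugins=[CurrentDataFeCAMUpdate()],
    )

    for experience in benchmark.train_stream:
        strategy.train(experience)
        strategy.eval(benchmark.test_stream)

The NCM variants are wired the same way, swapping FeCAMClassifier for NCMClassifier and CurrentDataFeCAMUpdate for CurrentDataNCMUpdate, or using the Memory*/Oracle plugins when a replay buffer or all data seen so far is available.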