diff --git a/avalanche/models/__init__.py b/avalanche/models/__init__.py
index 5e9baf6c2..cf93390af 100644
--- a/avalanche/models/__init__.py
+++ b/avalanche/models/__init__.py
@@ -19,3 +19,4 @@
 from .base_model import BaseModel
 from .helper_method import as_multitask
 from .pnn import PNN
+from .generator import *
diff --git a/avalanche/models/generator.py b/avalanche/models/generator.py
new file mode 100644
index 000000000..6d8a4440c
--- /dev/null
+++ b/avalanche/models/generator.py
@@ -0,0 +1,193 @@
+################################################################################
+# Copyright (c) 2021 ContinualAI.                                              #
+# Copyrights licensed under the MIT License.                                   #
+# See the accompanying LICENSE file for terms.                                 #
+#                                                                              #
+# Date: 03-03-2022                                                             #
+# Author: Florian Mies                                                         #
+# Website: https://github.com/travela                                          #
+################################################################################
+
+"""
+
+File to place any kind of generative models
+and their respective helper functions.
+
+"""
+
+from abc import abstractmethod
+import torch
+import torch.nn as nn
+from torchvision import transforms
+from avalanche.models.utils import MLP, Flatten
+from avalanche.models.base_model import BaseModel
+
+
+class Generator(BaseModel):
+    """
+    A base abstract class for generators.
+    """
+
+    @abstractmethod
+    def generate(self, batch_size=None, condition=None):
+        """
+        Lets the generator sample random samples.
+        Output is either a single sample or, if provided,
+        a batch of samples of size "batch_size".
+
+        :param batch_size: Number of samples to generate.
+        :param condition: Possible condition for a conditional generator
+            (e.g. a class label).
+        """
+
+
+###########################
+# VARIATIONAL AUTOENCODER #
+###########################
+
+
+class VAEMLPEncoder(nn.Module):
+    '''
+    Encoder part of the VAE; computes the latent representation of the input.
+
+    :param shape: Shape of the input to the network: (channels, height, width)
+    :param latent_dim: Dimension of the last hidden layer
+    '''
+
+    def __init__(self, shape, latent_dim=128):
+        super(VAEMLPEncoder, self).__init__()
+        flattened_size = torch.Size(shape).numel()
+        self.encode = nn.Sequential(
+            Flatten(),
+            nn.Linear(in_features=flattened_size, out_features=400),
+            nn.BatchNorm1d(400),
+            nn.LeakyReLU(),
+            MLP([400, latent_dim])
+        )
+
+    def forward(self, x, y=None):
+        x = self.encode(x)
+        return x
+
+
+class VAEMLPDecoder(nn.Module):
+    '''
+    Decoder part of the VAE. Reverses the Encoder.
+
+    :param shape: Shape of the output: (channels, height, width).
+    :param nhid: Dimension of the input (the latent space).
+    '''
+
+    def __init__(self, shape, nhid=16):
+        super(VAEMLPDecoder, self).__init__()
+        flattened_size = torch.Size(shape).numel()
+        self.shape = shape
+        self.decode = nn.Sequential(
+            MLP([nhid, 64, 128, 256, flattened_size], last_activation=False),
+            nn.Sigmoid())
+        self.invTrans = transforms.Compose([
+            transforms.Normalize((0.1307,), (0.3081,))
+        ])
+
+    def forward(self, z, y=None):
+        if y is None:
+            return self.invTrans(self.decode(z).view(-1, *self.shape))
+        else:
+            return self.invTrans(self.decode(torch.cat((z, y), dim=1))
+                                 .view(-1, *self.shape))
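+
+# A minimal shape sketch (illustrative only, assuming MNIST-shaped inputs;
+# MlpVAE below wires these two modules together):
+#
+#   enc = VAEMLPEncoder((1, 28, 28), latent_dim=128)
+#   dec = VAEMLPDecoder((1, 28, 28), nhid=16)
+#   h = enc(torch.rand(32, 1, 28, 28))   # -> (32, 128) latent features
+#   x_hat = dec(torch.rand(32, 16))      # -> (32, 1, 28, 28) reconstruction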
+
+
+class MlpVAE(Generator, nn.Module):
+    '''
+    Variational autoencoder module:
+    fully-connected and suited for any input shape and type.
+
+    The encoder only computes the latent representation,
+    and we then have two possible output heads:
+    one for the usual output distribution and one for classification.
+    The latter is an extension of the conventional VAE and incorporates
+    a classifier into the network.
+    More details can be found in: https://arxiv.org/abs/1809.10635
+    '''
+
+    def __init__(self, shape, nhid=16, n_classes=10, device="cpu"):
+        """
+        :param shape: Shape of each input sample.
+        :param nhid: Dimension of the latent space of the Encoder.
+        :param n_classes: Number of classes -
+            defines the dimension of the classification head.
+        """
+        super(MlpVAE, self).__init__()
+        self.dim = nhid
+        self.device = device
+        self.encoder = VAEMLPEncoder(shape, latent_dim=128)
+        self.calc_mean = MLP([128, nhid], last_activation=False)
+        self.calc_logvar = MLP([128, nhid], last_activation=False)
+        self.classification = MLP([128, n_classes], last_activation=False)
+        self.decoder = VAEMLPDecoder(shape, nhid)
+
+    def get_features(self, x):
+        """
+        Get features from the encoder part given input x.
+        """
+        return self.encoder(x)
+
+    def generate(self, batch_size=None):
+        """
+        Generate random samples.
+        Output is a single sample if batch_size=None,
+        else a batch of samples of size "batch_size".
+        """
+        z = torch.randn((batch_size, self.dim)).to(
+            self.device) if batch_size else torch.randn((1, self.dim)).to(
+            self.device)
+        res = self.decoder(z)
+        if not batch_size:
+            res = res.squeeze(0)
+        return res
+
+    def sampling(self, mean, logvar):
+        """
+        VAE 'reparametrization trick':
+        z = mean + eps * sigma, with eps ~ N(0, I).
+        """
+        eps = torch.randn(mean.shape).to(self.device)
+        sigma = torch.exp(0.5 * logvar)  # standard deviation from logvar
+        return mean + eps * sigma
+
+    def forward(self, x):
+        """
+        Forward pass: encode, sample a latent code, decode.
+        """
+        representations = self.encoder(x)
+        mean, logvar = self.calc_mean(
+            representations), self.calc_logvar(representations)
+        z = self.sampling(mean, logvar)
+        return self.decoder(z), mean, logvar
+
+
+# Loss functions
+BCE_loss = nn.BCELoss(reduction="sum")
+MSE_loss = nn.MSELoss(reduction="sum")
+CE_loss = nn.CrossEntropyLoss()
+
+
+def VAE_loss(X, forward_output):
+    '''
+    Loss function of a VAE using mean squared error for the reconstruction
+    loss. This is the criterion for the VAE training loop.
+
+    :param X: Original input batch.
+    :param forward_output: Return value of a VAE.forward() call.
+        Triplet consisting of (X_hat, mean, logvar), i.e.
+        (reconstructed input after Encoder and Decoder,
+        mean of the VAE output distribution,
+        logvar of the VAE output distribution).
+    '''
+    X_hat, mean, logvar = forward_output
+    reconstruction_loss = MSE_loss(X_hat, X)
+    KL_divergence = 0.5 * torch.sum(-1 - logvar + torch.exp(logvar) + mean**2)
+    return reconstruction_loss + KL_divergence
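+
+
+# Expected call pattern during training (a sketch only; the actual wiring is
+# done by the VAETraining strategy in
+# avalanche/training/supervised/strategy_wrappers.py):
+#
+#   vae = MlpVAE((1, 28, 28), nhid=16)
+#   X = torch.rand(32, 1, 28, 28)
+#   loss = VAE_loss(X, vae(X))  # vae(X) returns (X_hat, mean, logvar)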
+
+
+__all__ = ["MlpVAE", "VAE_loss"]
diff --git a/avalanche/models/utils.py b/avalanche/models/utils.py
index 0f08297e3..38fdf4b33 100644
--- a/avalanche/models/utils.py
+++ b/avalanche/models/utils.py
@@ -1,6 +1,7 @@
 from avalanche.benchmarks.utils import AvalancheDataset
 from avalanche.models.dynamic_modules import MultiTaskModule, DynamicModule
 import torch.nn as nn
+from collections import OrderedDict
 
 
 def avalanche_forward(model, x, task_labels):
@@ -59,4 +60,46 @@ def add_hooks(self, model):
         )
 
 
-__all__ = ["avalanche_forward", "FeatureExtractorBackbone"]
+class Flatten(nn.Module):
+    '''
+    Simple nn.Module to flatten each tensor of a batch of tensors.
+    '''
+
+    def __init__(self):
+        super(Flatten, self).__init__()
+
+    def forward(self, x):
+        batch_size = x.shape[0]
+        return x.view(batch_size, -1)
+
+
+class MLP(nn.Module):
+    '''
+    Simple nn.Module to create a multi-layer perceptron
+    with BatchNorm and ReLU activations.
+
+    :param hidden_size: An array indicating the number of neurons in each
+        layer.
+    :type hidden_size: int[]
+    :param last_activation: Indicates whether to add BatchNorm and ReLU
+        after the last layer.
+    :type last_activation: Boolean
+    '''
+
+    def __init__(self, hidden_size, last_activation=True):
+        super(MLP, self).__init__()
+        q = []
+        for i in range(len(hidden_size) - 1):
+            in_dim = hidden_size[i]
+            out_dim = hidden_size[i + 1]
+            q.append(("Linear_%d" % i, nn.Linear(in_dim, out_dim)))
+            if (i < len(hidden_size) - 2) or (
+                    (i == len(hidden_size) - 2) and last_activation):
+                q.append(("BatchNorm_%d" % i, nn.BatchNorm1d(out_dim)))
+                q.append(("ReLU_%d" % i, nn.ReLU(inplace=True)))
+        self.mlp = nn.Sequential(OrderedDict(q))
+
+    def forward(self, x):
+        return self.mlp(x)
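+
+
+# For illustration, MLP([784, 400, 10], last_activation=False) builds:
+#   Linear(784, 400) -> BatchNorm1d(400) -> ReLU -> Linear(400, 10)
+# (the layer sizes here are just an example).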
+
+
+__all__ = ["avalanche_forward", "FeatureExtractorBackbone", "MLP", "Flatten"]
diff --git a/avalanche/training/plugins/__init__.py b/avalanche/training/plugins/__init__.py
index 52ca874ed..da82b7c06 100644
--- a/avalanche/training/plugins/__init__.py
+++ b/avalanche/training/plugins/__init__.py
@@ -13,5 +13,7 @@
 from .lfl import LFLPlugin
 from .early_stopping import EarlyStoppingPlugin
 from .lr_scheduling import LRSchedulerPlugin
+from .generative_replay import GenerativeReplayPlugin, \
+    TrainGeneratorAfterExpPlugin
 from .rwalk import RWalkPlugin
 from .mas import MASPlugin
diff --git a/avalanche/training/plugins/generative_replay.py b/avalanche/training/plugins/generative_replay.py
new file mode 100644
index 000000000..0a5d12c89
--- /dev/null
+++ b/avalanche/training/plugins/generative_replay.py
@@ -0,0 +1,162 @@
+################################################################################
+# Copyright (c) 2021 ContinualAI.                                              #
+# Copyrights licensed under the MIT License.                                   #
+# See the accompanying LICENSE file for terms.                                 #
+#                                                                              #
+# Date: 05-03-2022                                                             #
+# Author: Florian Mies                                                         #
+# Website: https://github.com/travela                                          #
+################################################################################
+
+"""
+
+All plugins related to Generative Replay.
+
+"""
+
+from copy import deepcopy
+from avalanche.core import SupervisedPlugin
+from avalanche.training.templates.base import BaseTemplate
+from avalanche.training.templates.supervised import SupervisedTemplate
+import torch
+
+
+class GenerativeReplayPlugin(SupervisedPlugin):
+    """
+    Experience generative replay plugin.
+
+    Updates the current mbatch of a strategy before training an experience
+    by sampling a generator model and concatenating the replay data to the
+    current batch.
+
+    In this version of the plugin the number of replay samples is
+    increased with each new experience. Another way to implement
+    the algorithm is by weighting the loss function and giving more
+    importance to the replayed data as the number of experiences
+    increases. This will be implemented as an option for the user soon.
+
+    :param generator_strategy: In case the plugin is applied to a
+        non-generative model (e.g. a simple classifier), this should contain
+        an Avalanche strategy for a model that implements a 'generate' method
+        (see avalanche.models.generator.Generator). Defaults to None.
+    :param untrained_solver: if True we assume this is the beginning of
+        a continual learning task and add replay data only from the second
+        experience onwards, otherwise we sample and add generative replay
+        data before training the first experience. Defaults to True.
+    :param replay_size: The user can specify the batch size of replays that
+        should be added to each data batch. By default each data batch will
+        be matched with an equal number of replay samples.
+    :param increasing_replay_size: If set to True, the number of replay
+        samples added to each data batch grows with every additional
+        experience (current batch size times the current experience index),
+        so that older experiences gradually gain importance in the final
+        loss.
+    """
+
+    def __init__(self, generator_strategy: "BaseTemplate" = None,
+                 untrained_solver: bool = True, replay_size: int = None,
+                 increasing_replay_size: bool = False):
+        '''
+        Init.
+        '''
+        super().__init__()
+        self.generator_strategy = generator_strategy
+        if self.generator_strategy:
+            self.generator = generator_strategy.model
+        else:
+            self.generator = None
+        self.untrained_solver = untrained_solver
+        self.model_is_generator = False
+        self.replay_size = replay_size
+        self.increasing_replay_size = increasing_replay_size
+
+    def before_training(self, strategy: "SupervisedTemplate", *args, **kwargs):
+        """Checks whether we are using a user-defined external generator
+        or the strategy's own model as the generator.
+        If the generator is None after initialization,
+        we assume that strategy.model is the generator
+        (e.g. this is the case when training a VAE with
+        generative replay)."""
+        if not self.generator_strategy:
+            self.generator_strategy = strategy
+            self.generator = strategy.model
+            self.model_is_generator = True
+
+    def before_training_exp(self, strategy: "SupervisedTemplate",
+                            num_workers: int = 0, shuffle: bool = True,
+                            **kwargs):
+        """
+        Make deep copies of generator and solver before training a new
+        experience.
+        """
+        if self.untrained_solver:
+            # The solver needs to be trained before labelling generated data
+            # and the generator needs to be trained before we can sample.
+            return
+        self.old_generator = deepcopy(self.generator)
+        self.old_generator.eval()
+        if not self.model_is_generator:
+            self.old_model = deepcopy(strategy.model)
+            self.old_model.eval()
+
+    def after_training_exp(self, strategy: "SupervisedTemplate",
+                           num_workers: int = 0, shuffle: bool = True,
+                           **kwargs):
+        """
+        Set the untrained_solver flag to False after (the first) experience,
+        in order to start training with replay data from the second
+        experience onwards.
+        """
+        self.untrained_solver = False
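+
+    # Sizing of the replay batch generated below (a sketch, assuming a real
+    # minibatch of 100 samples):
+    #   replay_size=50               -> 50 generated samples per minibatch
+    #   increasing_replay_size=True  -> 100 * current_experience samples
+    #   default                      -> 100 samples (one per real sample)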
+
+    def before_training_iteration(self, strategy: "SupervisedTemplate",
+                                  **kwargs):
+        """
+        Generate and append replay data to the current minibatch before
+        each training iteration.
+        """
+        if self.untrained_solver:
+            # The solver needs to be trained before labelling generated data
+            # and the generator needs to be trained before we can sample.
+            return
+        # determine how many replay data points to generate
+        if self.replay_size:
+            number_replays_to_generate = self.replay_size
+        else:
+            if self.increasing_replay_size:
+                number_replays_to_generate = len(
+                    strategy.mbatch[0]) * (
+                    strategy.experience.current_experience)
+            else:
+                number_replays_to_generate = len(strategy.mbatch[0])
+        # extend X with replay data
+        replay = self.old_generator.generate(number_replays_to_generate
+                                             ).to(strategy.device)
+        strategy.mbatch[0] = torch.cat([strategy.mbatch[0], replay], dim=0)
+        # extend y with predicted labels (or mock labels if model==generator)
+        if not self.model_is_generator:
+            with torch.no_grad():
+                replay_output = self.old_model(replay).argmax(dim=-1)
+        else:
+            # Mock labels:
+            replay_output = torch.zeros(replay.shape[0])
+        strategy.mbatch[1] = torch.cat(
+            [strategy.mbatch[1], replay_output.to(strategy.device)], dim=0)
+        # extend the task-id batch (we implicitly assume a task-free case)
+        strategy.mbatch[-1] = torch.cat([strategy.mbatch[-1], torch.ones(
+            replay.shape[0]).to(strategy.device) * strategy.mbatch[-1][0]],
+            dim=0)
+
+
+class TrainGeneratorAfterExpPlugin(SupervisedPlugin):
+    """
+    TrainGeneratorAfterExpPlugin makes sure that after each experience of
+    training the solver of a scholar model, we also train the generator on
+    the data of the current experience.
+    """
+
+    def after_training_exp(self, strategy: "SupervisedTemplate", **kwargs):
+        """
+        The training method expects an Experience object
+        with a 'dataset' parameter.
+        """
+        for plugin in strategy.plugins:
+            if type(plugin) is GenerativeReplayPlugin:
+                plugin.generator_strategy.train(strategy.experience)
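+
+
+# Typical wiring of these plugins into a scholar (solver + generator) setup,
+# as a sketch only; the GenerativeReplay strategy in
+# avalanche/training/supervised/strategy_wrappers.py builds this for you:
+#
+#   generator_strategy = VAETraining(
+#       vae, vae_optimizer, plugins=[GenerativeReplayPlugin()])
+#   solver_strategy = Naive(
+#       solver, solver_optimizer, criterion,
+#       plugins=[GenerativeReplayPlugin(generator_strategy=generator_strategy),
+#                TrainGeneratorAfterExpPlugin()])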
diff --git a/avalanche/training/supervised/strategy_wrappers.py b/avalanche/training/supervised/strategy_wrappers.py
index 509435794..73a977430 100644
--- a/avalanche/training/supervised/strategy_wrappers.py
+++ b/avalanche/training/supervised/strategy_wrappers.py
@@ -19,6 +19,8 @@
     SupervisedPlugin,
     CWRStarPlugin,
     ReplayPlugin,
+    GenerativeReplayPlugin,
+    TrainGeneratorAfterExpPlugin,
     GDumbPlugin,
     LwFPlugin,
     AGEMPlugin,
@@ -31,7 +33,10 @@
     LFLPlugin,
     MASPlugin
 )
+from avalanche.training.templates.base import BaseTemplate
 from avalanche.training.templates.supervised import SupervisedTemplate
+from avalanche.models.generator import MlpVAE, VAE_loss
+from avalanche.logging import InteractiveLogger
 
 
 class Naive(SupervisedTemplate):
@@ -277,6 +282,194 @@ def __init__(
         )
 
 
+class GenerativeReplay(SupervisedTemplate):
+    """Generative Replay Strategy.
+
+    This implements Deep Generative Replay for a Scholar consisting of a
+    Solver and a Generator as described in https://arxiv.org/abs/1705.08690.
+
+    The model parameter should contain the solver. As an optional input,
+    a generator can be wrapped in a trainable strategy
+    and passed to the generator_strategy parameter. By default a simple VAE
+    will be used as the generator.
+
+    For the case where the Generator is itself the model that is to be
+    trained, simply add the GenerativeReplayPlugin() when instantiating
+    your Generator's strategy.
+
+    See GenerativeReplayPlugin for more details.
+    This strategy does not use task identities.
+    """
+
+    def __init__(
+        self,
+        model: Module,
+        optimizer: Optimizer,
+        criterion=CrossEntropyLoss(),
+        train_mb_size: int = 1,
+        train_epochs: int = 1,
+        eval_mb_size: int = None,
+        device=None,
+        plugins: Optional[List[SupervisedPlugin]] = None,
+        evaluator: EvaluationPlugin = default_evaluator,
+        eval_every=-1,
+        generator_strategy: BaseTemplate = None,
+        replay_size: int = None,
+        increasing_replay_size: bool = False,
+        **base_kwargs
+    ):
+        """
+        Creates an instance of the Generative Replay Strategy
+        for a solver-generator pair.
+
+        :param model: The solver model.
+        :param optimizer: The optimizer to use.
+        :param criterion: The loss criterion to use.
+        :param train_mb_size: The train minibatch size. Defaults to 1.
+        :param train_epochs: The number of training epochs. Defaults to 1.
+        :param eval_mb_size: The eval minibatch size. Defaults to 1.
+        :param device: The device to use. Defaults to None (cpu).
+        :param plugins: Plugins to be added. Defaults to None.
+        :param evaluator: (optional) instance of EvaluationPlugin for logging
+            and metric computations.
+        :param eval_every: the frequency of the calls to `eval` inside the
+            training loop. -1 disables the evaluation. 0 means `eval` is
+            called only at the end of the learning experience. Values >0
+            mean that `eval` is called every `eval_every` epochs and at the
+            end of the learning experience.
+        :param generator_strategy: A trainable strategy with a generative
+            model, which employs GenerativeReplayPlugin. Defaults to None.
+        :param replay_size: The batch size of replay samples added to each
+            data batch. Defaults to None, i.e. match the current batch size.
+        :param increasing_replay_size: If True, the number of replay samples
+            grows with the number of experiences seen so far.
+            Defaults to False.
+        :param **base_kwargs: any additional
+            :class:`~avalanche.training.BaseTemplate` constructor arguments.
+        """
+
+        # Check if user inputs a generator model
+        # (which is wrapped in a strategy that can be trained and
+        # uses the GenerativeReplayPlugin;
+        # see 'VAETraining' as an example below.)
+        if generator_strategy is not None:
+            self.generator_strategy = generator_strategy
+        else:
+            # By default we use a fully-connected VAE as the generator.
+            # model:
+            generator = MlpVAE((1, 28, 28), nhid=2, device=device)
+            # optimizer:
+            lr = 0.01
+            from torch.optim import Adam
+            optimizer_generator = Adam(filter(
+                lambda p: p.requires_grad, generator.parameters()), lr=lr,
+                weight_decay=0.0001)
+            # strategy (with plugin):
+            self.generator_strategy = VAETraining(
+                model=generator,
+                optimizer=optimizer_generator,
+                criterion=VAE_loss, train_mb_size=train_mb_size,
+                train_epochs=train_epochs,
+                eval_mb_size=eval_mb_size, device=device,
+                plugins=[GenerativeReplayPlugin(
+                    replay_size=replay_size,
+                    increasing_replay_size=increasing_replay_size)])
+
+        rp = GenerativeReplayPlugin(
+            generator_strategy=self.generator_strategy,
+            replay_size=replay_size,
+            increasing_replay_size=increasing_replay_size)
+
+        tgp = TrainGeneratorAfterExpPlugin()
+
+        if plugins is None:
+            plugins = [tgp, rp]
+        else:
+            plugins.append(tgp)
+            plugins.append(rp)
+
+        super().__init__(
+            model,
+            optimizer,
+            criterion,
+            train_mb_size=train_mb_size,
+            train_epochs=train_epochs,
+            eval_mb_size=eval_mb_size,
+            device=device,
+            plugins=plugins,
+            evaluator=evaluator,
+            eval_every=eval_every,
+            **base_kwargs
+        )
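+
+    # Minimal usage sketch (see examples/generative_replay_splitMNIST.py for
+    # a complete script); `model` is the solver, e.g. a SimpleMLP classifier:
+    #
+    #   strategy = GenerativeReplay(model, Adam(model.parameters(), lr=1e-3),
+    #                               CrossEntropyLoss(), train_mb_size=100,
+    #                               train_epochs=4, device=device)
+    #   for experience in benchmark.train_stream:
+    #       strategy.train(experience)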
+
+
+class VAETraining(SupervisedTemplate):
+    """VAETraining class.
+
+    This is the training strategy for the VAE model
+    found in the models directory.
+    We make use of the SupervisedTemplate, even though technically this is
+    not supervised training. However, this keeps the modifications to a
+    minimum.
+
+    We only need to overwrite the criterion function in order to pass all
+    necessary variables to the VAE loss function.
+    Furthermore, we remove all metrics from the evaluator.
+    """
+
+    def __init__(
+        self,
+        model: Module,
+        optimizer: Optimizer,
+        criterion=VAE_loss,
+        train_mb_size: int = 1,
+        train_epochs: int = 1,
+        eval_mb_size: int = None,
+        device=None,
+        plugins: Optional[List[SupervisedPlugin]] = None,
+        evaluator: EvaluationPlugin = EvaluationPlugin(
+            loggers=[InteractiveLogger()],
+            suppress_warnings=True,
+        ),
+        eval_every=-1,
+        **base_kwargs
+    ):
+        """
+        Creates an instance of the VAETraining strategy.
+
+        :param model: The model.
+        :param optimizer: The optimizer to use.
+        :param criterion: The loss criterion to use. Defaults to VAE_loss.
+        :param train_mb_size: The train minibatch size. Defaults to 1.
+        :param train_epochs: The number of training epochs. Defaults to 1.
+        :param eval_mb_size: The eval minibatch size. Defaults to 1.
+        :param device: The device to use. Defaults to None (cpu).
+        :param plugins: Plugins to be added. Defaults to None.
+        :param evaluator: (optional) instance of EvaluationPlugin for logging
+            and metric computations.
+        :param eval_every: the frequency of the calls to `eval` inside the
+            training loop. -1 disables the evaluation. 0 means `eval` is
+            called only at the end of the learning experience. Values >0
+            mean that `eval` is called every `eval_every` epochs and at the
+            end of the learning experience.
+        :param **base_kwargs: any additional
+            :class:`~avalanche.training.BaseTemplate` constructor arguments.
+        """
+
+        super().__init__(
+            model,
+            optimizer,
+            criterion,
+            train_mb_size=train_mb_size,
+            train_epochs=train_epochs,
+            eval_mb_size=eval_mb_size,
+            device=device,
+            plugins=plugins,
+            evaluator=evaluator,
+            eval_every=eval_every,
+            **base_kwargs
+        )
+
+    def criterion(self):
+        """Adapt the input to the criterion as needed to compute the
+        reconstruction loss and KL divergence. See the default criterion
+        VAE_loss."""
+        return self._criterion(self.mb_x, self.mb_output)
+
+
 class GSS_greedy(SupervisedTemplate):
     """Experience replay strategy.
 
@@ -1000,6 +1193,8 @@ def __init__(
     "PNNStrategy",
     "CWRStar",
     "Replay",
+    "GenerativeReplay",
+    "VAETraining",
     "GDumb",
     "LwF",
     "AGEM",
diff --git a/examples/generative_replay_MNIST_generator.py b/examples/generative_replay_MNIST_generator.py
new file mode 100644
index 000000000..c22df5a96
--- /dev/null
+++ b/examples/generative_replay_MNIST_generator.py
@@ -0,0 +1,93 @@
+################################################################################
+# Copyright (c) 2021 ContinualAI.                                              #
+# Copyrights licensed under the MIT License.                                   #
+# See the accompanying LICENSE file for terms.                                 #
+#                                                                              #
+# Date: 01-04-2022                                                             #
+# Author(s): Florian Mies                                                      #
+# E-mail: contact@continualai.org                                              #
+# Website: avalanche.continualai.org                                           #
+################################################################################
+
+"""
+This is a simple example on how to train a VAE on SplitMNIST with generative
+replay, i.e. the generator itself is the continually trained model.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import torch
+from torch.nn import CrossEntropyLoss
+from torchvision import transforms
+from torchvision.transforms import ToTensor, RandomCrop
+import torch.optim.lr_scheduler
+import matplotlib.pyplot as plt
+import numpy as np
+from avalanche.benchmarks import SplitMNIST
+from avalanche.models import MlpVAE
+from avalanche.training.supervised import VAETraining
+from avalanche.training.plugins import GenerativeReplayPlugin
+from avalanche.logging import InteractiveLogger
+
+
+def main(args):
+    # --- CONFIG
+    device = torch.device(
+        f"cuda:{args.cuda}"
+        if torch.cuda.is_available() and args.cuda >= 0
+        else "cpu"
+    )
+
+    # --- SCENARIO CREATION
+    scenario = SplitMNIST(n_experiences=10, seed=1234)
+    # ---------
+
+    # MODEL CREATION
+    model = MlpVAE((1, 28, 28), nhid=2, device=device)
+
+    # CREATE THE STRATEGY INSTANCE (VAETraining with GenerativeReplayPlugin)
+    cl_strategy = VAETraining(
+        model,
+        torch.optim.Adam(model.parameters(), lr=0.001),
+        train_mb_size=100,
+        train_epochs=4,
+        device=device,
+        plugins=[GenerativeReplayPlugin()]
+    )
+
+    # TRAINING LOOP
+    print("Starting experiment...")
+    f, axarr = plt.subplots(scenario.n_experiences, 10)
+    k = 0
+    for experience in scenario.train_stream:
+        print("Start of experience ", experience.current_experience)
+        cl_strategy.train(experience)
+        print("Training completed")
+
+        samples = model.generate(10)
+        samples = samples.detach().cpu().numpy()
+
+        for j in range(10):
+            axarr[k, j].imshow(samples[j, 0], cmap="gray")
+        axarr[k, 4].set_title("Generated images for experience " + str(k))
+        np.vectorize(lambda ax: ax.axis('off'))(axarr)
+        k += 1
+
+    f.subplots_adjust(hspace=1.2)
+    plt.savefig("VAE_output_per_exp")
+    plt.show()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--cuda",
+        type=int,
+        default=0,
+        help="Select zero-indexed cuda device. -1 to use CPU.",
+    )
+    args = parser.parse_args()
+    main(args)
diff --git a/examples/generative_replay_splitMNIST.py b/examples/generative_replay_splitMNIST.py
new file mode 100644
index 000000000..05c8cf62c
--- /dev/null
+++ b/examples/generative_replay_splitMNIST.py
@@ -0,0 +1,99 @@
+################################################################################
+# Copyright (c) 2021 ContinualAI.                                              #
+# Copyrights licensed under the MIT License.                                   #
+# See the accompanying LICENSE file for terms.                                 #
+#                                                                              #
+# Date: 01-04-2022                                                             #
+# Author(s): Florian Mies                                                      #
+# E-mail: contact@continualai.org                                              #
+# Website: avalanche.continualai.org                                           #
+################################################################################
+
+"""
+This is a simple example on how to use the GenerativeReplay strategy
+to train a classifier on SplitMNIST with a VAE-based generator.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import argparse
+import torch
+from torch.nn import CrossEntropyLoss
+from torchvision import transforms
+from torchvision.transforms import ToTensor, RandomCrop
+import torch.optim.lr_scheduler
+from avalanche.benchmarks import SplitMNIST
+from avalanche.models import SimpleMLP
+from avalanche.training.supervised import GenerativeReplay
+from avalanche.evaluation.metrics import (
+    forgetting_metrics,
+    accuracy_metrics,
+    loss_metrics,
+)
+from avalanche.logging import InteractiveLogger
+from avalanche.training.plugins import EvaluationPlugin
+
+
+def main(args):
+    # --- CONFIG
+    device = torch.device(
+        f"cuda:{args.cuda}"
+        if torch.cuda.is_available() and args.cuda >= 0
+        else "cpu"
+    )
+
+    # --- SCENARIO CREATION
+    scenario = SplitMNIST(n_experiences=10, seed=1234)
+    # ---------
+
+    # MODEL CREATION
+    model = SimpleMLP(num_classes=scenario.n_classes)
+
+    # choose some metrics and evaluation method
+    interactive_logger = InteractiveLogger()
+
+    eval_plugin = EvaluationPlugin(
+        accuracy_metrics(
+            minibatch=True, epoch=True, experience=True, stream=True
+        ),
+        loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
+        forgetting_metrics(experience=True),
+        loggers=[interactive_logger],
+    )
+
+    # CREATE THE STRATEGY INSTANCE (GenerativeReplay)
+    cl_strategy = GenerativeReplay(
+        model,
+        torch.optim.Adam(model.parameters(), lr=0.001),
+        CrossEntropyLoss(),
+        train_mb_size=100,
+        train_epochs=4,
+        eval_mb_size=100,
+        device=device,
+        evaluator=eval_plugin,
+    )
+
+    # TRAINING LOOP
+    print("Starting experiment...")
+    results = []
+    for experience in scenario.train_stream:
+        print("Start of experience ", experience.current_experience)
+        cl_strategy.train(experience)
+        print("Training completed")
+
+        print("Computing accuracy on the whole test set")
+        results.append(cl_strategy.eval(scenario.test_stream))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--cuda",
+        type=int,
+        default=0,
+        help="Select zero-indexed cuda device. -1 to use CPU.",
+    )
+    args = parser.parse_args()
+    main(args)
diff --git a/tests/training/test_strategies_accuracy.py b/tests/training/test_strategies_accuracy.py
index e06c1fbe3..c0c93681c 100644
--- a/tests/training/test_strategies_accuracy.py
+++ b/tests/training/test_strategies_accuracy.py
@@ -74,10 +74,10 @@ def test_multihead_cumulative(self):
             model,
             optimizer,
             criterion,
-            train_mb_size=32,
+            train_mb_size=64,
             device=get_device(),
             eval_mb_size=512,
-            train_epochs=3,
+            train_epochs=6,
             evaluator=evalp,
         )
         benchmark = get_fast_benchmark(use_task_labels=True)