diff --git a/avalanche/core.py b/avalanche/core.py index 829daa7e2..ac13aac9f 100644 --- a/avalanche/core.py +++ b/avalanche/core.py @@ -193,3 +193,33 @@ def after_eval_dataset_adaptation( ) -> CallbackResult: """Called after `eval_dataset_adaptation` by the `BaseTemplate`.""" pass + + +class SupervisedMetaLearningPlugin(SupervisedPlugin[Template], ABC): + """ABC for SupervisedMetaLearningTemplate plugins. + + See `BaseTemplate` for a complete description of the train/eval loop. + """ + def before_inner_updates( + self, strategy: Template, *args, **kwargs + ) -> CallbackResult: + """Called before `_inner_updates` by the `BaseTemplate`.""" + pass + + def after_inner_updates( + self, strategy: Template, *args, **kwargs + ) -> CallbackResult: + """Called after `_inner_updates` by the `BaseTemplate`.""" + pass + + def before_outer_update( + self, strategy: Template, *args, **kwargs + ) -> CallbackResult: + """Called before `_outer_update` by the `BaseTemplate`.""" + pass + + def after_outer_update( + self, strategy: Template, *args, **kwargs + ) -> CallbackResult: + """Called after `_outer_update` by the `BaseTemplate`.""" + pass diff --git a/avalanche/evaluation/metric_definitions.py b/avalanche/evaluation/metric_definitions.py index 7aa8782e3..ca16604eb 100644 --- a/avalanche/evaluation/metric_definitions.py +++ b/avalanche/evaluation/metric_definitions.py @@ -21,7 +21,7 @@ if TYPE_CHECKING: from .metric_results import MetricResult - from ..training.templates.supervised import SupervisedTemplate + from ..training.templates import SupervisedTemplate TResult = TypeVar("TResult") TAggregated = TypeVar("TAggregated", bound="PluginMetric") diff --git a/avalanche/evaluation/metric_utils.py b/avalanche/evaluation/metric_utils.py index 01e2a56bd..7cba57b3b 100644 --- a/avalanche/evaluation/metric_utils.py +++ b/avalanche/evaluation/metric_utils.py @@ -28,7 +28,7 @@ from torch import Tensor if TYPE_CHECKING: - from avalanche.training.templates.supervised import SupervisedTemplate + from avalanche.training.templates import SupervisedTemplate from avalanche.benchmarks.scenarios import ClassificationExperience from avalanche.evaluation import PluginMetric diff --git a/avalanche/evaluation/metrics/checkpoint.py b/avalanche/evaluation/metrics/checkpoint.py index 51000890d..e4ec12fe1 100644 --- a/avalanche/evaluation/metrics/checkpoint.py +++ b/avalanche/evaluation/metrics/checkpoint.py @@ -19,7 +19,7 @@ from avalanche.evaluation.metric_utils import get_metric_name if TYPE_CHECKING: - from avalanche.training.templates.supervised import SupervisedTemplate + from avalanche.training.templates import SupervisedTemplate class WeightCheckpoint(PluginMetric[Tensor]): diff --git a/avalanche/evaluation/metrics/confusion_matrix.py b/avalanche/evaluation/metrics/confusion_matrix.py index 88a3bc419..a0910104f 100644 --- a/avalanche/evaluation/metrics/confusion_matrix.py +++ b/avalanche/evaluation/metrics/confusion_matrix.py @@ -41,7 +41,7 @@ ) if TYPE_CHECKING: - from avalanche.training.templates.supervised import SupervisedTemplate + from avalanche.training.templates import SupervisedTemplate class ConfusionMatrix(Metric[Tensor]): diff --git a/avalanche/evaluation/metrics/forgetting_bwt.py b/avalanche/evaluation/metrics/forgetting_bwt.py index d0652a3d0..2abfae036 100644 --- a/avalanche/evaluation/metrics/forgetting_bwt.py +++ b/avalanche/evaluation/metrics/forgetting_bwt.py @@ -21,7 +21,7 @@ ) if TYPE_CHECKING: - from avalanche.training.templates.supervised import SupervisedTemplate + from
avalanche.training.templates import SupervisedTemplate class Forgetting(Metric[Union[float, None, Dict[int, float]]]): diff --git a/avalanche/evaluation/metrics/forward_transfer.py b/avalanche/evaluation/metrics/forward_transfer.py index f6eb934bd..fdd41b482 100644 --- a/avalanche/evaluation/metrics/forward_transfer.py +++ b/avalanche/evaluation/metrics/forward_transfer.py @@ -21,7 +21,7 @@ ) if TYPE_CHECKING: - from avalanche.training.templates.supervised import SupervisedTemplate + from avalanche.training.templates import SupervisedTemplate class ForwardTransfer(Metric[Union[float, None, Dict[int, float]]]): diff --git a/avalanche/evaluation/metrics/gpu_usage.py b/avalanche/evaluation/metrics/gpu_usage.py index 7ae1f4648..6304a6213 100644 --- a/avalanche/evaluation/metrics/gpu_usage.py +++ b/avalanche/evaluation/metrics/gpu_usage.py @@ -20,7 +20,7 @@ from avalanche.evaluation.metric_results import MetricResult if TYPE_CHECKING: - from avalanche.training.templates.supervised import SupervisedTemplate + from avalanche.training.templates import SupervisedTemplate class MaxGPU(Metric[float]): diff --git a/avalanche/evaluation/metrics/images_samples.py b/avalanche/evaluation/metrics/images_samples.py index 841766737..e7ae4113e 100644 --- a/avalanche/evaluation/metrics/images_samples.py +++ b/avalanche/evaluation/metrics/images_samples.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: - from avalanche.training.templates.supervised import SupervisedTemplate + from avalanche.training.templates import SupervisedTemplate from avalanche.benchmarks.utils import make_classification_dataset diff --git a/avalanche/evaluation/metrics/labels_repartition.py b/avalanche/evaluation/metrics/labels_repartition.py index 2b85f855c..436e40849 100644 --- a/avalanche/evaluation/metrics/labels_repartition.py +++ b/avalanche/evaluation/metrics/labels_repartition.py @@ -26,7 +26,7 @@ if TYPE_CHECKING: - from avalanche.training.templates.supervised import SupervisedTemplate + from avalanche.training.templates import SupervisedTemplate from avalanche.evaluation.metric_results import MetricResult diff --git a/avalanche/evaluation/metrics/mean_scores.py b/avalanche/evaluation/metrics/mean_scores.py index bdc422f83..d4d927891 100644 --- a/avalanche/evaluation/metrics/mean_scores.py +++ b/avalanche/evaluation/metrics/mean_scores.py @@ -32,7 +32,7 @@ from typing_extensions import Literal if TYPE_CHECKING: - from avalanche.training.templates.supervised import SupervisedTemplate + from avalanche.training.templates import SupervisedTemplate from avalanche.evaluation.metric_results import MetricResult diff --git a/avalanche/evaluation/metrics/ram_usage.py b/avalanche/evaluation/metrics/ram_usage.py index 701ad8581..364454218 100644 --- a/avalanche/evaluation/metrics/ram_usage.py +++ b/avalanche/evaluation/metrics/ram_usage.py @@ -19,7 +19,7 @@ from avalanche.evaluation.metric_results import MetricResult if TYPE_CHECKING: - from avalanche.training.templates.supervised import SupervisedTemplate + from avalanche.training.templates import SupervisedTemplate class MaxRAM(Metric[float]): diff --git a/avalanche/evaluation/metrics/timing.py b/avalanche/evaluation/metrics/timing.py index 2704fc4c0..eb09ca8c2 100644 --- a/avalanche/evaluation/metrics/timing.py +++ b/avalanche/evaluation/metrics/timing.py @@ -18,7 +18,7 @@ from avalanche.evaluation.metrics.mean import Mean if TYPE_CHECKING: - from avalanche.training.templates.supervised import SupervisedTemplate + from avalanche.training.templates import SupervisedTemplate class 
ElapsedTime(Metric[float]): diff --git a/avalanche/logging/base_logger.py b/avalanche/logging/base_logger.py index 9e03daa87..77b86864e 100644 --- a/avalanche/logging/base_logger.py +++ b/avalanche/logging/base_logger.py @@ -4,7 +4,7 @@ if TYPE_CHECKING: from avalanche.evaluation.metric_results import MetricValue - from avalanche.training.templates.supervised import SupervisedTemplate + from avalanche.training.templates import SupervisedTemplate class BaseLogger(ABC): diff --git a/avalanche/logging/interactive_logging.py b/avalanche/logging/interactive_logging.py index fcd253c92..a4c52a2d9 100644 --- a/avalanche/logging/interactive_logging.py +++ b/avalanche/logging/interactive_logging.py @@ -19,7 +19,7 @@ from tqdm import tqdm if TYPE_CHECKING: - from avalanche.training.templates.supervised import SupervisedTemplate + from avalanche.training.templates import SupervisedTemplate class InteractiveLogger(TextLogger, SupervisedPlugin): @@ -61,6 +61,8 @@ def before_training_epoch( metric_values: List["MetricValue"], **kwargs ): + if isinstance(strategy.experience, OnlineCLExperience): + return super().before_training_epoch(strategy, metric_values, **kwargs) self._progress.total = len(strategy.dataloader) @@ -70,6 +72,8 @@ def after_training_epoch( metric_values: List["MetricValue"], **kwargs ): + if isinstance(strategy.experience, OnlineCLExperience): + return self._end_progress() super().after_training_epoch(strategy, metric_values, **kwargs) diff --git a/avalanche/logging/wandb_logger.py b/avalanche/logging/wandb_logger.py index dd523916d..5c4902992 100644 --- a/avalanche/logging/wandb_logger.py +++ b/avalanche/logging/wandb_logger.py @@ -34,7 +34,7 @@ if TYPE_CHECKING: from avalanche.evaluation.metric_results import MetricValue - from avalanche.training.templates.supervised import SupervisedTemplate + from avalanche.training.templates import SupervisedTemplate class WandBLogger(BaseLogger, SupervisedPlugin): diff --git a/avalanche/training/plugins/evaluation.py b/avalanche/training/plugins/evaluation.py index 7467b9850..22a7dfda6 100644 --- a/avalanche/training/plugins/evaluation.py +++ b/avalanche/training/plugins/evaluation.py @@ -10,7 +10,7 @@ if TYPE_CHECKING: from avalanche.evaluation import PluginMetric from avalanche.logging import BaseLogger - from avalanche.training.templates.supervised import SupervisedTemplate + from avalanche.training.templates import SupervisedTemplate class EvaluationPlugin: diff --git a/avalanche/training/plugins/gdumb.py b/avalanche/training/plugins/gdumb.py index 85c921be5..be44c8cdc 100644 --- a/avalanche/training/plugins/gdumb.py +++ b/avalanche/training/plugins/gdumb.py @@ -5,7 +5,7 @@ from avalanche.training.storage_policy import ClassBalancedBuffer if TYPE_CHECKING: - from avalanche.training.templates.supervised import SupervisedTemplate + from avalanche.training.templates import SupervisedTemplate class GDumbPlugin(SupervisedPlugin): diff --git a/avalanche/training/plugins/gss_greedy.py b/avalanche/training/plugins/gss_greedy.py index 234f04052..12f9e7cfc 100644 --- a/avalanche/training/plugins/gss_greedy.py +++ b/avalanche/training/plugins/gss_greedy.py @@ -6,7 +6,7 @@ from avalanche.training.plugins.strategy_plugin import SupervisedPlugin if TYPE_CHECKING: - from ..templates.supervised import SupervisedTemplate + from ..templates import SupervisedTemplate class GSS_greedyPlugin(SupervisedPlugin): diff --git a/avalanche/training/plugins/lr_scheduling.py b/avalanche/training/plugins/lr_scheduling.py index 46c29ba63..e288ff915 100644 --- 
a/avalanche/training/plugins/lr_scheduling.py +++ b/avalanche/training/plugins/lr_scheduling.py @@ -8,7 +8,7 @@ import inspect if TYPE_CHECKING: - from avalanche.training.templates.supervised import SupervisedTemplate + from avalanche.training.templates import SupervisedTemplate class LRSchedulerPlugin(SupervisedPlugin): diff --git a/avalanche/training/plugins/replay.py b/avalanche/training/plugins/replay.py index 141c47be6..f653a1834 100644 --- a/avalanche/training/plugins/replay.py +++ b/avalanche/training/plugins/replay.py @@ -9,7 +9,7 @@ ) if TYPE_CHECKING: - from avalanche.training.templates.supervised import SupervisedTemplate + from avalanche.training.templates import SupervisedTemplate class ReplayPlugin(SupervisedPlugin): diff --git a/avalanche/training/plugins/synaptic_intelligence.py b/avalanche/training/plugins/synaptic_intelligence.py index 79379fce0..6efd4da42 100644 --- a/avalanche/training/plugins/synaptic_intelligence.py +++ b/avalanche/training/plugins/synaptic_intelligence.py @@ -13,7 +13,7 @@ from avalanche.training.utils import get_layers_and_params if TYPE_CHECKING: - from ..templates.supervised import SupervisedTemplate + from ..templates import SupervisedTemplate SynDataType = Dict[str, Dict[str, Tensor]] diff --git a/avalanche/training/storage_policy.py b/avalanche/training/storage_policy.py index 0ea3d6d1b..37e222192 100644 --- a/avalanche/training/storage_policy.py +++ b/avalanche/training/storage_policy.py @@ -17,7 +17,7 @@ from ..benchmarks.utils.utils import concat_datasets if TYPE_CHECKING: - from .templates.supervised import SupervisedTemplate + from .templates import SupervisedTemplate class ExemplarsBuffer(ABC): diff --git a/avalanche/training/supervised/ar1.py b/avalanche/training/supervised/ar1.py index 918b61bee..882aa97f2 100644 --- a/avalanche/training/supervised/ar1.py +++ b/avalanche/training/supervised/ar1.py @@ -16,7 +16,7 @@ SynapticIntelligencePlugin, CWRStarPlugin, ) -from avalanche.training.templates.supervised import SupervisedTemplate +from avalanche.training.templates import SupervisedTemplate from avalanche.training.utils import ( replace_bn_with_brn, get_last_fc_layer, diff --git a/avalanche/training/supervised/cumulative.py b/avalanche/training/supervised/cumulative.py index 596c4aa1e..f2ae3981b 100644 --- a/avalanche/training/supervised/cumulative.py +++ b/avalanche/training/supervised/cumulative.py @@ -8,7 +8,7 @@ from avalanche.benchmarks.utils.utils import concat_datasets from avalanche.training.plugins.evaluation import default_evaluator from avalanche.training.plugins import SupervisedPlugin, EvaluationPlugin -from avalanche.training.templates.supervised import SupervisedTemplate +from avalanche.training.templates import SupervisedTemplate class Cumulative(SupervisedTemplate): diff --git a/avalanche/training/supervised/deep_slda.py b/avalanche/training/supervised/deep_slda.py index 06446d86d..676652636 100644 --- a/avalanche/training/supervised/deep_slda.py +++ b/avalanche/training/supervised/deep_slda.py @@ -5,7 +5,7 @@ import torch from avalanche.training.plugins import SupervisedPlugin -from avalanche.training.templates.supervised import SupervisedTemplate +from avalanche.training.templates import SupervisedTemplate from avalanche.training.plugins.evaluation import default_evaluator from avalanche.models.dynamic_modules import MultiTaskModule from avalanche.models import FeatureExtractorBackbone diff --git a/avalanche/training/supervised/icarl.py b/avalanche/training/supervised/icarl.py index ac7b9c6e7..0622f06e9 100644 --- 
a/avalanche/training/supervised/icarl.py +++ b/avalanche/training/supervised/icarl.py @@ -19,7 +19,7 @@ from avalanche.training.plugins.strategy_plugin import SupervisedPlugin from torch.nn import Module from torch.utils.data import DataLoader -from avalanche.training.templates.supervised import SupervisedTemplate +from avalanche.training.templates import SupervisedTemplate class ICaRL(SupervisedTemplate): diff --git a/avalanche/training/supervised/joint_training.py b/avalanche/training/supervised/joint_training.py index bc2087510..71fd0fbf2 100644 --- a/avalanche/training/supervised/joint_training.py +++ b/avalanche/training/supervised/joint_training.py @@ -18,7 +18,7 @@ from avalanche.benchmarks.utils import concat_classification_datasets from avalanche.benchmarks.utils.utils import concat_datasets from avalanche.training.plugins.evaluation import default_evaluator -from avalanche.training.templates.supervised import SupervisedTemplate +from avalanche.training.templates import SupervisedTemplate from avalanche.models import DynamicModule if TYPE_CHECKING: diff --git a/avalanche/training/supervised/lamaml.py b/avalanche/training/supervised/lamaml.py index ec27ed8fd..431f3f55c 100644 --- a/avalanche/training/supervised/lamaml.py +++ b/avalanche/training/supervised/lamaml.py @@ -5,39 +5,23 @@ import torch.nn.functional as F from torch.nn import Module, CrossEntropyLoss from torch.optim import Optimizer +import math try: import higher except ImportError: - raise ModuleNotFoundError( - "higher not found, if you want to use " - "MAML please install avalanche with " - "the extra dependencies: " - "pip install avalanche-lib[extra]" - ) -import math + raise ModuleNotFoundError("higher not found, if you want to use " + "MAML please install avalanche with " + "the extra dependencies: " + "pip install avalanche-lib[extra]") from avalanche.training.plugins import SupervisedPlugin, EvaluationPlugin from avalanche.training.plugins.evaluation import default_evaluator -from avalanche.training.templates import SupervisedTemplate +from avalanche.training.templates import SupervisedMetaLearningTemplate from avalanche.models.utils import avalanche_forward -def init_kaiming_normal(m): - if isinstance(m, nn.Conv2d): - torch.nn.init.constant_(m.weight.data, 1.0) - torch.nn.init.kaiming_normal_(m.weight.data) - if m.bias is not None: - m.bias.data.zero_() - - elif isinstance(m, nn.Linear): - torch.nn.init.constant_(m.weight.data, 1.0) - torch.nn.init.kaiming_normal_(m.weight.data) - if m.bias is not None: - m.bias.data.zero_() - - -class LaMAML(SupervisedTemplate): +class LaMAML(SupervisedMetaLearningTemplate): def __init__( self, model: Module, @@ -141,21 +125,18 @@ def _before_training_exp(self, **kwargs): self.alpha_params.parameters(), lr=self.lr_alpha ) - def training_epoch(self, **kwargs): - for self.mbatch in self.dataloader: - if self._stop_training: - break - - self._unpack_minibatch() - self._before_training_iteration(**kwargs) - self.loss = 0 - - self.train_batch() + def apply_grad(self, module, grads): + for i, p in enumerate(module.parameters()): + grad = grads[i] + if grad is None: + grad = torch.zeros(p.shape).float().to(self.device) - self.mb_output = self.forward() - self._after_training_iteration(**kwargs) + if p.grad is None: + p.grad = grad + else: + p.grad += grad - def inner_update(self, fast_model, x, y, t): + def inner_update_step(self, fast_model, x, y, t): """Update fast weights using current samples and return the updated fast model. 
""" @@ -192,20 +173,9 @@ def inner_update(self, fast_model, x, y, t): # Update fast model's weights fast_model.update_params(new_fast_params) - def apply_grad(self, module, grads): - for i, p in enumerate(module.parameters()): - grad = grads[i] - if grad is None: - grad = torch.zeros(p.shape).float().to(self.device) - - if p.grad is None: - p.grad = grad - else: - p.grad += grad - - def train_batch(self): + def _inner_updates(self, **kwargs): # Create a stateless copy of the model for inner-updates - fast_model = higher.patch.monkeypatch( + self.fast_model = higher.patch.monkeypatch( self.model, copy_initial_weights=True, track_higher_grads=self.second_order, @@ -219,28 +189,30 @@ def train_batch(self): bsize_data = batch_x.shape[0] rough_sz = math.ceil(bsize_data / self.n_inner_updates) - meta_losses = [0 for _ in range(self.n_inner_updates)] + self.meta_losses = [0 for _ in range(self.n_inner_updates)] for i in range(self.n_inner_updates): - batch_x_i = batch_x[i * rough_sz : (i + 1) * rough_sz] - batch_y_i = batch_y[i * rough_sz : (i + 1) * rough_sz] - batch_t_i = batch_t[i * rough_sz : (i + 1) * rough_sz] + batch_x_i = batch_x[i * rough_sz: (i + 1) * rough_sz] + batch_y_i = batch_y[i * rough_sz: (i + 1) * rough_sz] + batch_t_i = batch_t[i * rough_sz: (i + 1) * rough_sz] # We assume that samples for inner update are from the same task - self.inner_update(fast_model, batch_x_i, batch_y_i, batch_t_i) + self.inner_update_step(self.fast_model, batch_x_i, batch_y_i, + batch_t_i) # Compute meta-loss with the combination of batch and buffer samples logits_meta = avalanche_forward( - fast_model, self.mb_x, self.mb_task_id + self.fast_model, self.mb_x, self.mb_task_id ) meta_loss = self._criterion(logits_meta, self.mb_y) - meta_losses[i] = meta_loss + self.meta_losses[i] = meta_loss + def _outer_update(self, **kwargs): # Compute meta-gradient for the main model - meta_loss = sum(meta_losses) / len(meta_losses) + meta_loss = sum(self.meta_losses) / len(self.meta_losses) meta_grad_model = torch.autograd.grad( meta_loss, - fast_model.parameters(time=0), + self.fast_model.parameters(time=0), retain_graph=True, allow_unused=True, ) @@ -271,9 +243,23 @@ def train_batch(self): self.optimizer.step() else: for p, alpha in zip( - self.model.parameters(), self.alpha_params.parameters() + self.model.parameters(), self.alpha_params.parameters() ): # Use relu on updated LRs to avoid negative values p.data = p.data - p.grad * F.relu(alpha) self.loss = meta_loss + + +def init_kaiming_normal(m): + if isinstance(m, nn.Conv2d): + torch.nn.init.constant_(m.weight.data, 1.0) + torch.nn.init.kaiming_normal_(m.weight.data) + if m.bias is not None: + m.bias.data.zero_() + + elif isinstance(m, nn.Linear): + torch.nn.init.constant_(m.weight.data, 1.0) + torch.nn.init.kaiming_normal_(m.weight.data) + if m.bias is not None: + m.bias.data.zero_() diff --git a/avalanche/training/supervised/strategy_wrappers.py b/avalanche/training/supervised/strategy_wrappers.py index bd01207cc..c463d0540 100644 --- a/avalanche/training/supervised/strategy_wrappers.py +++ b/avalanche/training/supervised/strategy_wrappers.py @@ -34,7 +34,7 @@ MASPlugin, ) from avalanche.training.templates.base import BaseTemplate -from avalanche.training.templates.supervised import SupervisedTemplate +from avalanche.training.templates import SupervisedTemplate from avalanche.models.generator import MlpVAE, VAE_loss from avalanche.logging import InteractiveLogger diff --git a/avalanche/training/supervised/strategy_wrappers_online.py 
b/avalanche/training/supervised/strategy_wrappers_online.py index 3eb5d5003..d757e2401 100644 --- a/avalanche/training/supervised/strategy_wrappers_online.py +++ b/avalanche/training/supervised/strategy_wrappers_online.py @@ -15,13 +15,13 @@ from avalanche.training.plugins.evaluation import default_evaluator from avalanche.training.plugins import SupervisedPlugin, EvaluationPlugin -from avalanche.training.templates.online_supervised import ( +from avalanche.training.templates import ( OnlineSupervisedTemplate, ) class OnlineNaive(OnlineSupervisedTemplate): - """Naive finetuning. + """Online naive finetuning. The simplest (and least effective) Continual Learning strategy. Naive just incrementally fine tunes a single model without employing any method @@ -42,7 +42,7 @@ def __init__( eval_mb_size: int = None, device=None, plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator: EvaluationPlugin = default_evaluator, eval_every=-1, ): """ diff --git a/avalanche/training/templates/__init__.py b/avalanche/training/templates/__init__.py index 191c78e28..5438e8d9b 100644 --- a/avalanche/training/templates/__init__.py +++ b/avalanche/training/templates/__init__.py @@ -11,6 +11,8 @@ """ from .base import BaseTemplate from .base_sgd import BaseSGDTemplate -from .base_online_sgd import BaseOnlineSGDTemplate -from .online_supervised import OnlineSupervisedTemplate -from .supervised import SupervisedTemplate +from .common_templates import ( + SupervisedTemplate, + SupervisedMetaLearningTemplate, + OnlineSupervisedTemplate +) diff --git a/avalanche/training/templates/base_online_sgd.py b/avalanche/training/templates/base_online_sgd.py index 508289a36..e69de29bb 100644 --- a/avalanche/training/templates/base_online_sgd.py +++ b/avalanche/training/templates/base_online_sgd.py @@ -1,388 +0,0 @@ -from typing import Iterable, Sequence, Optional, Union, List - -import torch -from torch.nn import Module -from torch.optim import Optimizer - -from avalanche.benchmarks import CLExperience, CLStream -from avalanche.core import BaseSGDPlugin -from avalanche.training.plugins import SupervisedPlugin, EvaluationPlugin -from avalanche.training.plugins.clock import Clock -from avalanche.training.plugins.evaluation import default_evaluator -from avalanche.training.templates.base import BaseTemplate, ExpSequence - -from typing import TYPE_CHECKING - -from avalanche.training.utils import trigger_plugins - -if TYPE_CHECKING: - from avalanche.training.templates.supervised import SupervisedTemplate - - -class BaseOnlineSGDTemplate(BaseTemplate): - """Base class for continual learning skeletons. - - **Training loop** - The training loop is organized as follows:: - - train - train_exp # for each experience - - **Evaluation loop** - The evaluation loop is organized as follows:: - - eval - eval_exp # for each experience - - """ - - PLUGIN_CLASS = BaseSGDPlugin - - def __init__( - self, - model: Module, - optimizer: Optimizer, - train_mb_size: int = 1, - train_passes: int = 1, - eval_mb_size: Optional[int] = 1, - device="cpu", - plugins: Optional[List["SupervisedPlugin"]] = None, - evaluator: EvaluationPlugin = default_evaluator(), - eval_every=-1, - peval_mode="experience", - ): - """Init. - - :param model: PyTorch model. - :param optimizer: PyTorch optimizer. - :param train_mb_size: mini-batch size for training. - :param train_passes: number of training passes. - :param eval_mb_size: mini-batch size for eval. 
- :param evaluator: (optional) instance of EvaluationPlugin for logging - and metric computations. None to remove logging. - :param eval_every: the frequency of the calls to `eval` inside the - training loop. -1 disables the evaluation. 0 means `eval` is called - only at the end of the learning experience. Values >0 mean - that `eval` is called every `eval_every` experience and at the end - of the learning experience. - :param peval_mode: one of {'experience', 'iteration'}. Decides whether - the periodic evaluation during training should execute every - `eval_every` experiences or iterations (Default='experience'). - """ - super().__init__(model=model, device=device, plugins=plugins) - - self.optimizer: Optimizer = optimizer - """ PyTorch optimizer. """ - - self.train_passes: int = train_passes - """ Number of training passes. """ - - self.train_mb_size: int = train_mb_size - """ Training mini-batch size. """ - - self.eval_mb_size: int = ( - train_mb_size if eval_mb_size is None else eval_mb_size - ) - """ Eval mini-batch size. """ - - if evaluator is None: - evaluator = EvaluationPlugin() - self.plugins.append(evaluator) - self.evaluator = evaluator - """ EvaluationPlugin used for logging and metric computations. """ - - # Configure periodic evaluation. - assert peval_mode in {"experience", "iteration"} - self.eval_every = eval_every - peval = PeriodicEval(eval_every, peval_mode) - self.plugins.append(peval) - - self.clock = Clock() - """ Incremental counters for strategy events. """ - # WARNING: Clock needs to be the last plugin, otherwise - # counters will be wrong for plugins called after it. - self.plugins.append(self.clock) - - ################################################################### - # State variables. These are updated during the train/eval loops. # - ################################################################### - - self.dataloader = None - """ Dataloader. """ - - self.mbatch = None - """ Current mini-batch. """ - - self.mb_output = None - """ Model's output computed on the current mini-batch. """ - - self.loss = None - """ Loss of the current mini-batch. """ - - self._stop_training = False - - def train( - self, - experiences: Union[CLExperience, ExpSequence], - eval_streams: Optional[ - Sequence[Union[CLExperience, ExpSequence]] - ] = None, - **kwargs - ): - super().train(experiences, eval_streams, **kwargs) - return self.evaluator.get_last_metrics() - - @torch.no_grad() - def eval(self, exp_list: Union[CLExperience, CLStream], **kwargs): - """ - Evaluate the current model on a series of experiences and - returns the last recorded value for each metric. - - :param exp_list: CL experience information. - :param kwargs: custom arguments. - - :return: dictionary containing last recorded value for - each metric name - """ - super().eval(exp_list, **kwargs) - return self.evaluator.get_last_metrics() - - def _before_training_exp(self, **kwargs): - self.make_train_dataloader(**kwargs) - # Model Adaptation (e.g. 
freeze/add new units) - - # If strategy has access to the task boundaries, and the current - # sub-experience is the first sub-experience in the online (sub-)stream, - # then adapt the model with the full origin experience: - if self.experience.access_task_boundaries: - if self.experience.is_first_subexp: - self.model = self.model_adaptation() - self.make_optimizer() - # Otherwise, adapt to the current sub-experience: - else: - self.model = self.model_adaptation() - self.make_optimizer() - - super()._before_training_exp(**kwargs) - - def _train_exp(self, experience: CLExperience, eval_streams=None, **kwargs): - """Training loop over a single Experience object. - - :param experience: CL experience information. - :param eval_streams: list of streams for evaluation. - If None: use the training experience for evaluation. - Use [] if you do not want to evaluate during training. - :param kwargs: custom arguments. - """ - if eval_streams is None: - eval_streams = [experience] - for i, exp in enumerate(eval_streams): - if not isinstance(exp, Iterable): - eval_streams[i] = [exp] - - self.training_pass(**kwargs) - - def _before_eval_exp(self, **kwargs): - self.make_eval_dataloader(**kwargs) - # Model Adaptation (e.g. freeze/add new units) - self.model = self.model_adaptation() - super()._before_eval_exp(**kwargs) - - def _eval_exp(self, **kwargs): - self.eval_epoch(**kwargs) - - def make_train_dataloader(self, **kwargs): - """Assign dataloader to self.dataloader.""" - raise NotImplementedError() - - def make_eval_dataloader(self, **kwargs): - """Assign dataloader to self.dataloader.""" - raise NotImplementedError() - - def make_optimizer(self, **kwargs): - """Optimizer initialization.""" - raise NotImplementedError() - - def criterion(self): - """Compute loss function.""" - raise NotImplementedError() - - def forward(self): - """Compute the model's output given the current mini-batch.""" - raise NotImplementedError() - - def model_adaptation(self, model=None): - """Adapts the model to the current experience.""" - raise NotImplementedError() - - def stop_training(self): - """Signals to stop training at the next iteration.""" - self._stop_training = True - - def training_pass(self, **kwargs): - """Training pass. 
- - :param kwargs: - :return: - """ - for self.pass_itr in range(self.train_passes): - for self.mbatch in self.dataloader: - if self._stop_training: - break - - self._unpack_minibatch() - self._before_training_iteration(**kwargs) - - self.optimizer.zero_grad() - self.loss = 0 - - # Forward - self._before_forward(**kwargs) - self.mb_output = self.forward() - self._after_forward(**kwargs) - - # Loss & Backward - self.loss += self.criterion() - - self._before_backward(**kwargs) - self.backward() - self._after_backward(**kwargs) - - # Optimization step - self._before_update(**kwargs) - self.optimizer_step() - self._after_update(**kwargs) - - self._after_training_iteration(**kwargs) - - def backward(self): - """Run the backward pass.""" - self.loss.backward() - - def optimizer_step(self): - """Execute the optimizer step (weights update).""" - self.optimizer.step() - - def eval_epoch(self, **kwargs): - """Evaluation loop over the current `self.dataloader`.""" - for self.mbatch in self.dataloader: - self._unpack_minibatch() - self._before_eval_iteration(**kwargs) - - self._before_eval_forward(**kwargs) - self.mb_output = self.forward() - self._after_eval_forward(**kwargs) - self.loss = self.criterion() - - self._after_eval_iteration(**kwargs) - - def _unpack_minibatch(self): - """Move to device""" - for i in range(len(self.mbatch)): - self.mbatch[i] = self.mbatch[i].to(self.device) - - ######################################################### - # Plugin Triggers # - ######################################################### - - def _before_training_iteration(self, **kwargs): - trigger_plugins(self, "before_training_iteration", **kwargs) - - def _before_forward(self, **kwargs): - trigger_plugins(self, "before_forward", **kwargs) - - def _after_forward(self, **kwargs): - trigger_plugins(self, "after_forward", **kwargs) - - def _before_backward(self, **kwargs): - trigger_plugins(self, "before_backward", **kwargs) - - def _after_backward(self, **kwargs): - trigger_plugins(self, "after_backward", **kwargs) - - def _after_training_iteration(self, **kwargs): - trigger_plugins(self, "after_training_iteration", **kwargs) - - def _before_update(self, **kwargs): - trigger_plugins(self, "before_update", **kwargs) - - def _after_update(self, **kwargs): - trigger_plugins(self, "after_update", **kwargs) - - def _before_eval_iteration(self, **kwargs): - trigger_plugins(self, "before_eval_iteration", **kwargs) - - def _before_eval_forward(self, **kwargs): - trigger_plugins(self, "before_eval_forward", **kwargs) - - def _after_eval_forward(self, **kwargs): - trigger_plugins(self, "after_eval_forward", **kwargs) - - def _after_eval_iteration(self, **kwargs): - trigger_plugins(self, "after_eval_iteration", **kwargs) - - -class PeriodicEval(SupervisedPlugin): - """Schedules periodic evaluation during training. - - This plugin is automatically configured and added by the BaseTemplate. - """ - - def __init__(self, eval_every=-1, peval_mode="experience", do_initial=True): - """Init. - - :param eval_every: the frequency of the calls to `eval` inside the - training loop. -1 disables the evaluation. 0 means `eval` is called - only at the end of the learning experience. Values >0 mean - that `eval` is called every `eval_every` experience and at the - end of the learning experience. - :param peval_mode: one of {'experience', 'iteration'}. Decides whether - the periodic evaluation during training should execute every - `eval_every` experience or iterations - (Default='experience'). 
- :param do_initial: whether to evaluate before each `train` call. - Occasionally needed becuase some metrics need to know the - accuracy before training. - """ - super().__init__() - assert peval_mode in {"experience", "iteration"} - self.eval_every = eval_every - self.peval_mode = peval_mode - self.do_initial = do_initial and eval_every > -1 - self.do_final = None - self._is_eval_updated = False - - def before_training(self, strategy, **kwargs): - """Eval before each learning experience. - - Occasionally needed because some metrics need the accuracy before - training. - """ - if self.do_initial: - self._peval(strategy, **kwargs) - - def _peval(self, strategy, **kwargs): - for el in strategy._eval_streams: - strategy.eval(el, **kwargs) - - def _maybe_peval(self, strategy, counter, **kwargs): - if self.eval_every > 0 and counter % self.eval_every == 0: - self._peval(strategy, **kwargs) - - def after_training_exp(self, strategy: "BaseOnlineSGDTemplate", **kwargs): - """Periodic eval controlled by `self.eval_every` and - `self.peval_mode`.""" - if self.peval_mode == "experience": - self._maybe_peval( - strategy, strategy.clock.train_exp_counter, **kwargs - ) - - def after_training_iteration( - self, strategy: "BaseOnlineSGDTemplate", **kwargs - ): - """Periodic eval controlled by `self.eval_every` and - `self.peval_mode`.""" - if self.peval_mode == "iteration": - self._maybe_peval( - strategy, strategy.clock.train_exp_iterations, **kwargs - ) diff --git a/avalanche/training/templates/base_sgd.py b/avalanche/training/templates/base_sgd.py index 1521db340..dc0ba9d38 100644 --- a/avalanche/training/templates/base_sgd.py +++ b/avalanche/training/templates/base_sgd.py @@ -1,8 +1,10 @@ from typing import Iterable, Sequence, Optional, Union, List +from pkg_resources import parse_version import torch -from torch.nn import Module +from torch.nn import Module, CrossEntropyLoss from torch.optim import Optimizer +from torch.utils.data import DataLoader from avalanche.benchmarks import CLExperience, CLStream from avalanche.core import BaseSGDPlugin @@ -10,17 +12,14 @@ from avalanche.training.plugins.clock import Clock from avalanche.training.plugins.evaluation import default_evaluator from avalanche.training.templates.base import BaseTemplate, ExpSequence - -from typing import TYPE_CHECKING - +from avalanche.models.utils import avalanche_model_adaptation +from avalanche.benchmarks.utils.data_loader import TaskBalancedDataLoader, \ + collate_from_data_or_kwargs from avalanche.training.utils import trigger_plugins -if TYPE_CHECKING: - from avalanche.training.templates.supervised import SupervisedTemplate - class BaseSGDTemplate(BaseTemplate): - """Base class for continual learning skeletons. + """Base SGD class for continual learning skeletons. **Training loop** The training loop is organized as follows:: @@ -42,6 +41,7 @@ def __init__( self, model: Module, optimizer: Optimizer, + criterion=CrossEntropyLoss(), train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: Optional[int] = 1, @@ -55,6 +55,7 @@ def __init__( :param model: PyTorch model. :param optimizer: PyTorch optimizer. + :param criterion: loss function. :param train_mb_size: mini-batch size for training. :param train_epochs: number of training epochs. :param eval_mb_size: mini-batch size for eval. @@ -74,6 +75,9 @@ def __init__( self.optimizer: Optimizer = optimizer """ PyTorch optimizer. """ + self._criterion = criterion + """ Criterion. """ + self.train_epochs: int = train_epochs """ Number of training epochs. 
""" @@ -92,7 +96,7 @@ def __init__( """ EvaluationPlugin used for logging and metric computations. """ # Configure periodic evaluation. - assert peval_mode in {"epoch", "iteration"} + assert peval_mode in {"experience", "epoch", "iteration"} self.eval_every = eval_every peval = PeriodicEval(eval_every, peval_mode) self.plugins.append(peval) @@ -107,6 +111,17 @@ def __init__( # State variables. These are updated during the train/eval loops. # ################################################################### + self.adapted_dataset = None + """ Data used to train. It may be modified by plugins. Plugins can + append data to it (e.g. for replay). + + .. note:: + + This dataset may contain samples from different experiences. If you + want the original data for the current experience + use :attr:`.BaseTemplate.experience`. + """ + self.dataloader = None """ Dataloader. """ @@ -121,14 +136,13 @@ def __init__( self._stop_training = False - def train( - self, - experiences: Union[CLExperience, ExpSequence], - eval_streams: Optional[ - Sequence[Union[CLExperience, ExpSequence]] - ] = None, - **kwargs - ): + def train(self, + experiences: Union[CLExperience, + ExpSequence], + eval_streams: Optional[Sequence[Union[CLExperience, + ExpSequence]]] = None, + **kwargs): + super().train(experiences, eval_streams, **kwargs) return self.evaluator.get_last_metrics() @@ -147,56 +161,18 @@ def eval(self, exp_list: Union[CLExperience, CLStream], **kwargs): super().eval(exp_list, **kwargs) return self.evaluator.get_last_metrics() - def _before_training_exp(self, **kwargs): - self.make_train_dataloader(**kwargs) - # Model Adaptation (e.g. freeze/add new units) - self.model = self.model_adaptation() - self.make_optimizer() - super()._before_training_exp(**kwargs) - - def _train_exp(self, experience: CLExperience, eval_streams=None, **kwargs): - """Training loop over a single Experience object. - - :param experience: CL experience information. - :param eval_streams: list of streams for evaluation. - If None: use the training experience for evaluation. - Use [] if you do not want to evaluate during training. - :param kwargs: custom arguments. - """ - if eval_streams is None: - eval_streams = [experience] - for i, exp in enumerate(eval_streams): - if not isinstance(exp, Iterable): - eval_streams[i] = [exp] - for _ in range(self.train_epochs): - self._before_training_epoch(**kwargs) - - if self._stop_training: # Early stopping - self._stop_training = False - break - - self.training_epoch(**kwargs) - self._after_training_epoch(**kwargs) - - def _before_eval_exp(self, **kwargs): - self.make_eval_dataloader(**kwargs) - # Model Adaptation (e.g. freeze/add new units) - self.model = self.model_adaptation() - super()._before_eval_exp(**kwargs) + def _train_exp( + self, experience: CLExperience, eval_streams, **kwargs + ): + # Should be implemented in Observation Type + raise NotImplementedError() def _eval_exp(self, **kwargs): self.eval_epoch(**kwargs) - def make_train_dataloader(self, **kwargs): - """Assign dataloader to self.dataloader.""" - raise NotImplementedError() - - def make_eval_dataloader(self, **kwargs): - """Assign dataloader to self.dataloader.""" - raise NotImplementedError() - def make_optimizer(self, **kwargs): """Optimizer initialization.""" + # Should be implemented in Observation Type raise NotImplementedError() def criterion(self): @@ -216,39 +192,8 @@ def stop_training(self): self._stop_training = True def training_epoch(self, **kwargs): - """Training epoch. 
- - :param kwargs: - :return: - """ - for self.mbatch in self.dataloader: - if self._stop_training: - break - - self._unpack_minibatch() - self._before_training_iteration(**kwargs) - - self.optimizer.zero_grad() - self.loss = 0 - - # Forward - self._before_forward(**kwargs) - self.mb_output = self.forward() - self._after_forward(**kwargs) - - # Loss & Backward - self.loss += self.criterion() - - self._before_backward(**kwargs) - self.backward() - self._after_backward(**kwargs) - - # Optimization step - self._before_update(**kwargs) - self.optimizer_step() - self._after_update(**kwargs) - - self._after_training_iteration(**kwargs) + # Should be implemented in Update Type + raise NotImplementedError() def backward(self): """Run the backward pass.""" @@ -271,8 +216,168 @@ def eval_epoch(self, **kwargs): self._after_eval_iteration(**kwargs) + # ==================================================================> NEW + + def check_model_and_optimizer(self): + # Should be implemented in observation type + raise NotImplementedError() + + def _before_training_exp(self, **kwargs): + """Setup to train on a single experience.""" + # Data Adaptation (e.g. add new samples/data augmentation) + self._before_train_dataset_adaptation(**kwargs) + self.train_dataset_adaptation(**kwargs) + self._after_train_dataset_adaptation(**kwargs) + + self.make_train_dataloader(**kwargs) + + # Model Adaptation (e.g. freeze/add new units) + # self.model = self.model_adaptation() + # self.make_optimizer() + self.check_model_and_optimizer() + + super()._before_training_exp(**kwargs) + + def _train_exp( + self, experience: CLExperience, eval_streams=None, **kwargs + ): + """Training loop over a single Experience object. + + :param experience: CL experience information. + :param eval_streams: list of streams for evaluation. + If None: use the training experience for evaluation. + Use [] if you do not want to evaluate during training. + :param kwargs: custom arguments. + """ + if eval_streams is None: + eval_streams = [experience] + for i, exp in enumerate(eval_streams): + if not isinstance(exp, Iterable): + eval_streams[i] = [exp] + for _ in range(self.train_epochs): + self._before_training_epoch(**kwargs) + + if self._stop_training: # Early stopping + self._stop_training = False + break + + self.training_epoch(**kwargs) + self._after_training_epoch(**kwargs) + + def _save_train_state(self): + """Save the training state which may be modified by the eval loop. + + This currently includes: experience, adapted_dataset, dataloader, + is_training, and train/eval modes for each module. + + TODO: we probably need a better way to do this. + """ + state = super()._save_train_state() + new_state = { + "adapted_dataset": self.adapted_dataset, + "dataloader": self.dataloader, + } + return {**state, **new_state} + + def train_dataset_adaptation(self, **kwargs): + """Initialize `self.adapted_dataset`.""" + self.adapted_dataset = self.experience.dataset + self.adapted_dataset = self.adapted_dataset.train() + + def _load_train_state(self, prev_state): + super()._load_train_state(prev_state) + self.adapted_dataset = prev_state["adapted_dataset"] + self.dataloader = prev_state["dataloader"] + + def _before_eval_exp(self, **kwargs): + + # Data Adaptation + self._before_eval_dataset_adaptation(**kwargs) + self.eval_dataset_adaptation(**kwargs) + self._after_eval_dataset_adaptation(**kwargs) + + self.make_eval_dataloader(**kwargs) + # Model Adaptation (e.g.
freeze/add new units) + self.model = self.model_adaptation() + + super()._before_eval_exp(**kwargs) + + def make_train_dataloader( + self, + num_workers=0, + shuffle=True, + pin_memory=True, + persistent_workers=False, + **kwargs + ): + """Data loader initialization. + + Called at the start of each learning experience after the dataset + adaptation. + + :param num_workers: number of thread workers for the data loading. + :param shuffle: True if the data should be shuffled, False otherwise. + :param pin_memory: If True, the data loader will copy Tensors into CUDA + pinned memory before returning them. Defaults to True. + """ + + other_dataloader_args = {} + + if parse_version(torch.__version__) >= parse_version("1.7.0"): + other_dataloader_args["persistent_workers"] = persistent_workers + for k, v in kwargs.items(): + other_dataloader_args[k] = v + + self.dataloader = TaskBalancedDataLoader( + self.adapted_dataset, + oversample_small_groups=True, + num_workers=num_workers, + batch_size=self.train_mb_size, + shuffle=shuffle, + pin_memory=pin_memory, + **other_dataloader_args + ) + + def make_eval_dataloader( + self, num_workers=0, pin_memory=True, persistent_workers=False, **kwargs + ): + """ + Initializes the eval data loader. + :param num_workers: How many subprocesses to use for data loading. + 0 means that the data will be loaded in the main process. + (default: 0). + :param pin_memory: If True, the data loader will copy Tensors into CUDA + pinned memory before returning them. Defaults to True. + :param kwargs: + :return: + """ + other_dataloader_args = {} + + if parse_version(torch.__version__) >= parse_version("1.7.0"): + other_dataloader_args["persistent_workers"] = persistent_workers + for k, v in kwargs.items(): + other_dataloader_args[k] = v + + collate_from_data_or_kwargs(self.adapted_dataset, + other_dataloader_args) + self.dataloader = DataLoader( + self.adapted_dataset, + num_workers=num_workers, + batch_size=self.eval_mb_size, + pin_memory=pin_memory, + **other_dataloader_args + ) + + def eval_dataset_adaptation(self, **kwargs): + """Initialize `self.adapted_dataset`.""" + self.adapted_dataset = self.experience.dataset + self.adapted_dataset = self.adapted_dataset.eval() + def _unpack_minibatch(self): """Move to device""" + # First verify the mini-batch + self._check_minibatch() + for i in range(len(self.mbatch)): self.mbatch[i] = self.mbatch[i].to(self.device) @@ -322,6 +427,20 @@ def _after_eval_forward(self, **kwargs): def _after_eval_iteration(self, **kwargs): trigger_plugins(self, "after_eval_iteration", **kwargs) + # ==================================================================> NEW + + def _before_train_dataset_adaptation(self, **kwargs): + trigger_plugins(self, "before_train_dataset_adaptation", **kwargs) + + def _after_train_dataset_adaptation(self, **kwargs): + trigger_plugins(self, "after_train_dataset_adaptation", **kwargs) + + def _before_eval_dataset_adaptation(self, **kwargs): + trigger_plugins(self, "before_eval_dataset_adaptation", **kwargs) + + def _after_eval_dataset_adaptation(self, **kwargs): + trigger_plugins(self, "after_eval_dataset_adaptation", **kwargs) + class PeriodicEval(SupervisedPlugin): """Schedules periodic evaluation during training. @@ -345,7 +464,7 @@ def __init__(self, eval_every=-1, peval_mode="epoch", do_initial=True): accuracy before training. 
""" super().__init__() - assert peval_mode in {"epoch", "iteration"} + assert peval_mode in {"experience", "epoch", "iteration"} self.eval_every = eval_every self.peval_mode = peval_mode self.do_initial = do_initial and eval_every > -1 @@ -378,11 +497,6 @@ def before_training_exp(self, strategy, **kwargs): pass self.do_final = self.do_final and self.eval_every > -1 - def after_training_exp(self, strategy, **kwargs): - """Final eval after a learning experience.""" - if self.do_final: - self._peval(strategy, **kwargs) - def _peval(self, strategy, **kwargs): for el in strategy._eval_streams: strategy.eval(el, **kwargs) @@ -391,18 +505,31 @@ def _maybe_peval(self, strategy, counter, **kwargs): if self.eval_every > 0 and counter % self.eval_every == 0: self._peval(strategy, **kwargs) - def after_training_epoch(self, strategy: "BaseSGDTemplate", **kwargs): + def after_training_epoch(self, strategy: "BaseSGDTemplate", + **kwargs): """Periodic eval controlled by `self.eval_every` and `self.peval_mode`.""" if self.peval_mode == "epoch": - self._maybe_peval( - strategy, strategy.clock.train_exp_epochs, **kwargs - ) + self._maybe_peval(strategy, strategy.clock.train_exp_epochs, + **kwargs) - def after_training_iteration(self, strategy: "BaseSGDTemplate", **kwargs): + def after_training_iteration(self, strategy: "BaseSGDTemplate", + **kwargs): """Periodic eval controlled by `self.eval_every` and `self.peval_mode`.""" if self.peval_mode == "iteration": - self._maybe_peval( - strategy, strategy.clock.train_exp_iterations, **kwargs - ) + self._maybe_peval(strategy, strategy.clock.train_exp_iterations, + **kwargs) + + # ---> New + def after_training_exp(self, strategy, **kwargs): + """Final eval after a learning experience.""" + if self.do_final: + self._peval(strategy, **kwargs) + + # def after_training_exp(self, strategy: "BaseOnlineSGDTemplate", **kwargs): + # """Periodic eval controlled by `self.eval_every` and + # `self.peval_mode`.""" + # if self.peval_mode == "experience": + # self._maybe_peval(strategy, strategy.clock.train_exp_counter, + # **kwargs) diff --git a/avalanche/training/templates/common_templates.py b/avalanche/training/templates/common_templates.py new file mode 100644 index 000000000..a8f5e35e6 --- /dev/null +++ b/avalanche/training/templates/common_templates.py @@ -0,0 +1,340 @@ +from typing import Sequence, Optional + +from torch.nn import Module, CrossEntropyLoss +from torch.optim import Optimizer + +from avalanche.core import BaseSGDPlugin +from avalanche.training.plugins.evaluation import default_evaluator + +from .observation_type import * +from .problem_type import * +from .update_type import * +from .base_sgd import BaseSGDTemplate + + +class SupervisedTemplate(BatchObservation, SupervisedProblem, SGDUpdate, + BaseSGDTemplate): + """Base class for continual learning strategies. + + BaseTemplate is the super class of all task-based continual learning + strategies. It implements a basic training loop and callback system + that allows to execute code at each experience of the training loop. + Plugins can be used to implement callbacks to augment the training + loop with additional behavior (e.g. a memory buffer for replay). + + **Scenarios** + This strategy supports several continual learning scenarios: + + * class-incremental scenarios (no task labels) + * multi-task scenarios, where task labels are provided) + * multi-incremental scenarios, where the same task may be revisited + + The exact scenario depends on the data stream and whether it provides + the task labels. 
+ + **Training loop** + The training loop is organized as follows:: + + train + train_exp # for each experience + adapt_train_dataset + train_dataset_adaptation + make_train_dataloader + train_epoch # for each epoch + # forward + # backward + # model update + + **Evaluation loop** + The evaluation loop is organized as follows:: + + eval + eval_exp # for each experience + adapt_eval_dataset + eval_dataset_adaptation + make_eval_dataloader + eval_epoch # for each epoch + # forward + # backward + # model update + + """ + + PLUGIN_CLASS = BaseSGDPlugin + + def __init__( + self, + model: Module, + optimizer: Optimizer, + criterion=CrossEntropyLoss(), + train_mb_size: int = 1, + train_epochs: int = 1, + eval_mb_size: Optional[int] = 1, + device="cpu", + plugins: Optional[Sequence["BaseSGDPlugin"]] = None, + evaluator=default_evaluator, + eval_every=-1, + peval_mode="epoch", + ): + """Init. + + :param model: PyTorch model. + :param optimizer: PyTorch optimizer. + :param criterion: loss function. + :param train_mb_size: mini-batch size for training. + :param train_epochs: number of training epochs. + :param eval_mb_size: mini-batch size for eval. + :param device: PyTorch device where the model will be allocated. + :param plugins: (optional) list of StrategyPlugins. + :param evaluator: (optional) instance of EvaluationPlugin for logging + and metric computations. None to remove logging. + :param eval_every: the frequency of the calls to `eval` inside the + training loop. -1 disables the evaluation. 0 means `eval` is called + only at the end of the learning experience. Values >0 mean that + `eval` is called every `eval_every` epochs and at the end of the + learning experience. + :param peval_mode: one of {'epoch', 'iteration'}. Decides whether the + periodic evaluation during training should execute every + `eval_every` epochs or iterations (Default='epoch'). + """ + super().__init__( + model=model, + optimizer=optimizer, + criterion=criterion, + train_mb_size=train_mb_size, + train_epochs=train_epochs, + eval_mb_size=eval_mb_size, + device=device, + plugins=plugins, + evaluator=evaluator, + eval_every=eval_every, + peval_mode=peval_mode, + ) + ################################################################### + # State variables. These are updated during the train/eval loops. # + ################################################################### + + # self.adapted_dataset = None + # """ Data used to train. It may be modified by plugins. Plugins can + # append data to it (e.g. for replay). + # + # .. note:: + # + # This dataset may contain samples from different experiences. If you + # want the original data for the current experience + # use :attr:`.BaseTemplate.experience`. + + +class SupervisedMetaLearningTemplate(BatchObservation, SupervisedProblem, + MetaUpdate, BaseSGDTemplate): + """Base class for continual learning strategies. + + BaseTemplate is the super class of all task-based continual learning + strategies. It implements a basic training loop and callback system + that allows to execute code at each experience of the training loop. + Plugins can be used to implement callbacks to augment the training + loop with additional behavior (e.g. a memory buffer for replay). 
+ + **Scenarios** + This strategy supports several continual learning scenarios: + + * class-incremental scenarios (no task labels) + * multi-task scenarios, where task labels are provided) + * multi-incremental scenarios, where the same task may be revisited + + The exact scenario depends on the data stream and whether it provides + the task labels. + + **Training loop** + The training loop is organized as follows:: + + train + train_exp # for each experience + adapt_train_dataset + train_dataset_adaptation + make_train_dataloader + train_epoch # for each epoch + # forward + # backward + # model update + + **Evaluation loop** + The evaluation loop is organized as follows:: + + eval + eval_exp # for each experience + adapt_eval_dataset + eval_dataset_adaptation + make_eval_dataloader + eval_epoch # for each epoch + # forward + # backward + # model update + + """ + + PLUGIN_CLASS = BaseSGDPlugin + + def __init__( + self, + model: Module, + optimizer: Optimizer, + criterion=CrossEntropyLoss(), + train_mb_size: int = 1, + train_epochs: int = 1, + eval_mb_size: Optional[int] = 1, + device="cpu", + plugins: Optional[Sequence["BaseSGDPlugin"]] = None, + evaluator=default_evaluator, + eval_every=-1, + peval_mode="epoch", + ): + """Init. + + :param model: PyTorch model. + :param optimizer: PyTorch optimizer. + :param criterion: loss function. + :param train_mb_size: mini-batch size for training. + :param train_epochs: number of training epochs. + :param eval_mb_size: mini-batch size for eval. + :param device: PyTorch device where the model will be allocated. + :param plugins: (optional) list of StrategyPlugins. + :param evaluator: (optional) instance of EvaluationPlugin for logging + and metric computations. None to remove logging. + :param eval_every: the frequency of the calls to `eval` inside the + training loop. -1 disables the evaluation. 0 means `eval` is called + only at the end of the learning experience. Values >0 mean that + `eval` is called every `eval_every` epochs and at the end of the + learning experience. + :param peval_mode: one of {'epoch', 'iteration'}. Decides whether the + periodic evaluation during training should execute every + `eval_every` epochs or iterations (Default='epoch'). + """ + super().__init__( + model=model, + optimizer=optimizer, + criterion=criterion, + train_mb_size=train_mb_size, + train_epochs=train_epochs, + eval_mb_size=eval_mb_size, + device=device, + plugins=plugins, + evaluator=evaluator, + eval_every=eval_every, + peval_mode=peval_mode, + ) + ################################################################### + # State variables. These are updated during the train/eval loops. # + ################################################################### + + # self.adapted_dataset = None + # """ Data used to train. It may be modified by plugins. Plugins can + # append data to it (e.g. for replay). + # + # .. note:: + # + # This dataset may contain samples from different experiences. If you + # want the original data for the current experience + # use :attr:`.BaseTemplate.experience`. + + +class OnlineSupervisedTemplate(OnlineObservation, SupervisedProblem, SGDUpdate, + BaseSGDTemplate): + """Base class for continual learning strategies. + + BaseTemplate is the super class of all task-based continual learning + strategies. It implements a basic training loop and callback system + that allows to execute code at each experience of the training loop. + Plugins can be used to implement callbacks to augment the training + loop with additional behavior (e.g. 
a memory buffer for replay). + + **Scenarios** + This strategy supports several continual learning scenarios: + + * class-incremental scenarios (no task labels) + * multi-task scenarios, where task labels are provided) + * multi-incremental scenarios, where the same task may be revisited + + The exact scenario depends on the data stream and whether it provides + the task labels. + + **Training loop** + The training loop is organized as follows:: + + train + train_exp # for each experience + adapt_train_dataset + train_dataset_adaptation + make_train_dataloader + train_pass # for each pass + # forward + # backward + # model update + + **Evaluation loop** + The evaluation loop is organized as follows:: + + eval + eval_exp # for each experience + adapt_eval_dataset + eval_dataset_adaptation + make_eval_dataloader + eval_epoch # for each epoch + # forward + # backward + # model update + + """ + + PLUGIN_CLASS = BaseSGDPlugin + + def __init__( + self, + model: Module, + optimizer: Optimizer, + criterion=CrossEntropyLoss(), + train_mb_size: int = 1, + train_passes: int = 1, + eval_mb_size: Optional[int] = 1, + device="cpu", + plugins: Optional[Sequence["BaseSGDPlugin"]] = None, + evaluator=default_evaluator, + eval_every=-1, + peval_mode="experience", + ): + """Init. + + :param model: PyTorch model. + :param optimizer: PyTorch optimizer. + :param criterion: loss function. + :param train_mb_size: mini-batch size for training. + :param train_passes: number of training passes. + :param eval_mb_size: mini-batch size for eval. + :param device: PyTorch device where the model will be allocated. + :param plugins: (optional) list of StrategyPlugins. + :param evaluator: (optional) instance of EvaluationPlugin for logging + and metric computations. None to remove logging. + :param eval_every: the frequency of the calls to `eval` inside the + training loop. -1 disables the evaluation. 0 means `eval` is called + only at the end of the learning experience. Values >0 mean that + `eval` is called every `eval_every` experiences and at the end of + the learning experience. + :param peval_mode: one of {'experience', 'iteration'}. Decides whether + the periodic evaluation during training should execute every + `eval_every` experience or iterations (Default='experience'). + """ + super().__init__( + model=model, + optimizer=optimizer, + criterion=criterion, + train_mb_size=train_mb_size, + train_epochs=train_passes, + eval_mb_size=eval_mb_size, + device=device, + plugins=plugins, + evaluator=evaluator, + eval_every=eval_every, + peval_mode=peval_mode, + ) + + self.train_passes = train_passes diff --git a/avalanche/training/templates/observation_type/__init__.py b/avalanche/training/templates/observation_type/__init__.py new file mode 100644 index 000000000..4391bfbd7 --- /dev/null +++ b/avalanche/training/templates/observation_type/__init__.py @@ -0,0 +1,6 @@ +"""Observation types mainly define the way data samples are observed: + batch(multiple epochs) vs. 
online(one epoch) + +""" +from .batch_observation import BatchObservation +from .online_observation import OnlineObservation diff --git a/avalanche/training/templates/observation_type/batch_observation.py b/avalanche/training/templates/observation_type/batch_observation.py new file mode 100644 index 000000000..4ec073849 --- /dev/null +++ b/avalanche/training/templates/observation_type/batch_observation.py @@ -0,0 +1,31 @@ +from typing import Iterable + +from avalanche.benchmarks import CLExperience +from avalanche.models.dynamic_optimizers import reset_optimizer +from avalanche.models.utils import avalanche_model_adaptation + + +class BatchObservation: + def model_adaptation(self, model=None): + """Adapts the model to the current data. + + Calls the :class:`~avalanche.models.DynamicModule`s adaptation. + """ + if model is None: + model = self.model + avalanche_model_adaptation(model, self.experience) + return model.to(self.device) + + def make_optimizer(self): + """Optimizer initialization. + + Called before each training experiene to configure the optimizer. + """ + # we reset the optimizer's state after each experience. + # This allows to add new parameters (new heads) and + # freezing old units during the model's adaptation phase. + reset_optimizer(self.optimizer, self.model) + + def check_model_and_optimizer(self): + self.model = self.model_adaptation() + self.make_optimizer() diff --git a/avalanche/training/templates/observation_type/online_observation.py b/avalanche/training/templates/observation_type/online_observation.py new file mode 100644 index 000000000..d3dbfaac5 --- /dev/null +++ b/avalanche/training/templates/observation_type/online_observation.py @@ -0,0 +1,66 @@ +from typing import Iterable + +from avalanche.benchmarks import OnlineCLExperience +from avalanche.models.dynamic_optimizers import reset_optimizer +from avalanche.models.dynamic_optimizers import update_optimizer +from avalanche.models.utils import avalanche_model_adaptation + + +class OnlineObservation: + def make_optimizer(self): + """Optimizer initialization. + + Called before each training experience to configure the optimizer. + """ + # We reset the optimizer's state after each experience if task + # boundaries are given, otherwise it updates the optimizer only if + # new parameters are added to the model after each adaptation step. + + # We assume the current experience is an OnlineCLExperience: + if self.experience.access_task_boundaries: + reset_optimizer(self.optimizer, self.model) + + else: + update_optimizer(self.optimizer, + self.model_params_before_adaptation, + self.model.parameters(), + reset_state=False) + + def model_adaptation(self, model=None): + """Adapts the model to the current data. + + Calls the :class:`~avalanche.models.DynamicModule`s adaptation. 
+ """ + if model is None: + model = self.model + + # For training: + if isinstance(self.experience, OnlineCLExperience): + # If the strategy has access to task boundaries, adapt the model + # for the whole origin experience to add the + if self.experience.access_task_boundaries: + avalanche_model_adaptation(model, + self.experience.origin_experience) + else: + self.model_params_before_adaptation = list(model.parameters()) + avalanche_model_adaptation(model, self.experience) + + # For evaluation, the experience is not necessarily an online + # experience: + else: + avalanche_model_adaptation(model, self.experience) + + return model.to(self.device) + + def check_model_and_optimizer(self): + # If strategy has access to the task boundaries, and the current + # sub-experience is the first sub-experience in the online (sub-)stream, + # then adapt the model with the full origin experience: + if self.experience.access_task_boundaries: + if self.experience.is_first_subexp: + self.model = self.model_adaptation() + self.make_optimizer() + # Otherwise, adapt to the current sub-experience: + else: + self.model = self.model_adaptation() + self.make_optimizer() diff --git a/avalanche/training/templates/online_supervised.py b/avalanche/training/templates/online_supervised.py deleted file mode 100644 index 7eb4b895f..000000000 --- a/avalanche/training/templates/online_supervised.py +++ /dev/null @@ -1,343 +0,0 @@ -from typing import Sequence, Optional -from pkg_resources import parse_version - -import torch -from torch.nn import Module, CrossEntropyLoss -from torch.optim import Optimizer -from torch.utils.data import DataLoader - -from avalanche.benchmarks.utils.data_loader import ( - TaskBalancedDataLoader, - collate_from_data_or_kwargs, -) -from avalanche.models import avalanche_forward -from avalanche.models.dynamic_optimizers import reset_optimizer -from avalanche.models.utils import avalanche_model_adaptation -from avalanche.training.plugins import SupervisedPlugin -from avalanche.training.plugins.evaluation import default_evaluator -from avalanche.training.templates.base_online_sgd import BaseOnlineSGDTemplate -from avalanche.training.utils import trigger_plugins -from avalanche.benchmarks.scenarios import OnlineCLExperience -from avalanche.models.dynamic_optimizers import update_optimizer - - -class OnlineSupervisedTemplate(BaseOnlineSGDTemplate): - """Base class for continual learning strategies. - - BaseTemplate is the super class of all task-based continual learning - strategies. It implements a basic training loop and callback system - that allows to execute code at each experience of the training loop. - Plugins can be used to implement callbacks to augment the training - loop with additional behavior (e.g. a memory buffer for replay). - - **Scenarios** - This strategy supports several continual learning scenarios: - - * class-incremental scenarios (no task labels) - * multi-task scenarios, where task labels are provided) - * multi-incremental scenarios, where the same task may be revisited - - The exact scenario depends on the data stream and whether it provides - the task labels. 
- - **Training loop** - The training loop is organized as follows:: - - train - train_exp # for each experience - adapt_train_dataset - train_dataset_adaptation - make_train_dataloader - train_pass # for each pass - # forward - # backward - # model update - - **Evaluation loop** - The evaluation loop is organized as follows:: - - eval - eval_exp # for each experience - adapt_eval_dataset - eval_dataset_adaptation - make_eval_dataloader - eval_epoch # for each epoch - # forward - # backward - # model update - - """ - - PLUGIN_CLASS = SupervisedPlugin - - def __init__( - self, - model: Module, - optimizer: Optimizer, - criterion=CrossEntropyLoss(), - train_mb_size: int = 1, - train_passes: int = 1, - eval_mb_size: Optional[int] = 1, - device="cpu", - plugins: Optional[Sequence["SupervisedPlugin"]] = None, - evaluator=default_evaluator(), - eval_every=-1, - peval_mode="experience", - ): - """Init. - - :param model: PyTorch model. - :param optimizer: PyTorch optimizer. - :param criterion: loss function. - :param train_mb_size: mini-batch size for training. - :param train_passes: number of training passes. - :param eval_mb_size: mini-batch size for eval. - :param device: PyTorch device where the model will be allocated. - :param plugins: (optional) list of StrategyPlugins. - :param evaluator: (optional) instance of EvaluationPlugin for logging - and metric computations. None to remove logging. - :param eval_every: the frequency of the calls to `eval` inside the - training loop. -1 disables the evaluation. 0 means `eval` is called - only at the end of the learning experience. Values >0 mean that - `eval` is called every `eval_every` experiences and at the end of - the learning experience. - :param peval_mode: one of {'experience', 'iteration'}. Decides whether - the periodic evaluation during training should execute every - `eval_every` experience or iterations (Default='experience'). - """ - super().__init__( - model=model, - optimizer=optimizer, - train_mb_size=train_mb_size, - train_passes=train_passes, - eval_mb_size=eval_mb_size, - device=device, - plugins=plugins, - evaluator=evaluator, - eval_every=eval_every, - peval_mode=peval_mode, - ) - self._criterion = criterion - - ################################################################### - # State variables. These are updated during the train/eval loops. # - ################################################################### - - self.adapted_dataset = None - """ Data used to train. It may be modified by plugins. Plugins can - append data to it (e.g. for replay). - - .. note:: - - This dataset may contain samples from different experiences. If you - want the original data for the current experience - use :attr:`.BaseTemplate.experience`. - """ - - @property - def mb_x(self): - """Current mini-batch input.""" - return self.mbatch[0] - - @property - def mb_y(self): - """Current mini-batch target.""" - return self.mbatch[1] - - @property - def mb_task_id(self): - """Current mini-batch task labels.""" - assert len(self.mbatch) >= 3 - return self.mbatch[-1] - - def criterion(self): - """Loss function.""" - return self._criterion(self.mb_output, self.mb_y) - - def _before_training_exp(self, **kwargs): - """Setup to train on a single experience.""" - # Data Adaptation (e.g. 
add new samples/data augmentation) - self._before_train_dataset_adaptation(**kwargs) - self.train_dataset_adaptation(**kwargs) - self._after_train_dataset_adaptation(**kwargs) - super()._before_training_exp(**kwargs) - - def _load_train_state(self, prev_state): - super()._load_train_state(prev_state) - self.adapted_dataset = prev_state["adapted_dataset"] - self.dataloader = prev_state["dataloader"] - - def _save_train_state(self): - """Save the training state which may be modified by the eval loop. - - This currently includes: experience, adapted_dataset, dataloader, - is_training, and train/eval modes for each module. - - TODO: we probably need a better way to do this. - """ - state = super()._save_train_state() - new_state = { - "adapted_dataset": self.adapted_dataset, - "dataloader": self.dataloader, - } - return {**state, **new_state} - - def train_dataset_adaptation(self, **kwargs): - """Initialize `self.adapted_dataset`.""" - self.adapted_dataset = self.experience.dataset - self.adapted_dataset = self.adapted_dataset.train() - - def _before_eval_exp(self, **kwargs): - # Data Adaptation - self._before_eval_dataset_adaptation(**kwargs) - self.eval_dataset_adaptation(**kwargs) - self._after_eval_dataset_adaptation(**kwargs) - super()._before_eval_exp(**kwargs) - - def make_train_dataloader( - self, - num_workers=0, - shuffle=True, - pin_memory=True, - persistent_workers=False, - **kwargs - ): - """Data loader initialization. - - Called at the start of each learning experience after the dataset - adaptation. - - :param num_workers: number of thread workers for the data loading. - :param shuffle: True if the data should be shuffled, False otherwise. - :param pin_memory: If True, the data loader will copy Tensors into CUDA - pinned memory before returning them. Defaults to True. - """ - - other_dataloader_args = {} - - if parse_version(torch.__version__) >= parse_version("1.7.0"): - other_dataloader_args["persistent_workers"] = persistent_workers - for k, v in kwargs.items(): - other_dataloader_args[k] = v - - collate_from_data_or_kwargs(self.adapted_dataset, other_dataloader_args) - self.dataloader = TaskBalancedDataLoader( - self.adapted_dataset, - oversample_small_groups=True, - num_workers=num_workers, - batch_size=self.train_mb_size, - shuffle=shuffle, - pin_memory=pin_memory, - **other_dataloader_args - ) - - def make_eval_dataloader( - self, num_workers=0, pin_memory=True, persistent_workers=False, **kwargs - ): - """ - Initializes the eval data loader. - :param num_workers: How many subprocesses to use for data loading. - 0 means that the data will be loaded in the main process. - (default: 0). - :param pin_memory: If True, the data loader will copy Tensors into CUDA - pinned memory before returning them. Defaults to True. - :param kwargs: - :return: - """ - other_dataloader_args = {} - - if parse_version(torch.__version__) >= parse_version("1.7.0"): - other_dataloader_args["persistent_workers"] = persistent_workers - for k, v in kwargs.items(): - other_dataloader_args[k] = v - - collate_from_data_or_kwargs(self.adapted_dataset, other_dataloader_args) - self.dataloader = DataLoader( - self.adapted_dataset, - num_workers=num_workers, - batch_size=self.eval_mb_size, - pin_memory=pin_memory, - **other_dataloader_args - ) - - def forward(self): - """Compute the model's output given the current mini-batch.""" - return avalanche_forward(self.model, self.mb_x, self.mb_task_id) - - def model_adaptation(self, model=None): - """Adapts the model to the current data. 
- - Calls the :class:`~avalanche.models.DynamicModule`s adaptation. - """ - if model is None: - model = self.model - - # For training: - if isinstance(self.experience, OnlineCLExperience): - # If the strategy has access to task boundaries, adapt the model - # for the whole origin experience to add the - if self.experience.access_task_boundaries: - avalanche_model_adaptation( - model, self.experience.origin_experience - ) - else: - self.model_params_before_adaptation = list(model.parameters()) - avalanche_model_adaptation(model, self.experience) - - # For evaluation, the experience is not necessarily an online - # experience: - else: - avalanche_model_adaptation(model, self.experience) - - return model.to(self.device) - - def _unpack_minibatch(self): - """We assume mini-batches have the form . - This allows for arbitrary tensors between y and t. - Keep in mind that in the most general case mb_task_id is a tensor - which may contain different labels for each sample. - """ - assert len(self.mbatch) >= 3 - super()._unpack_minibatch() - - def eval_dataset_adaptation(self, **kwargs): - """Initialize `self.adapted_dataset`.""" - self.adapted_dataset = self.experience.dataset - self.adapted_dataset = self.adapted_dataset.eval() - - def make_optimizer(self): - """Optimizer initialization. - - Called before each training experience to configure the optimizer. - """ - # We reset the optimizer's state after each experience if task - # boundaries are given, otherwise it updates the optimizer only if - # new parameters are added to the model after each adaptation step. - - # We assume the current experience is an OnlineCLExperience: - if self.experience.access_task_boundaries: - reset_optimizer(self.optimizer, self.model) - - else: - update_optimizer( - self.optimizer, - self.model_params_before_adaptation, - self.model.parameters(), - reset_state=False, - ) - - ######################################################### - # Plugin Triggers # - ######################################################### - - def _before_train_dataset_adaptation(self, **kwargs): - trigger_plugins(self, "before_train_dataset_adaptation", **kwargs) - - def _after_train_dataset_adaptation(self, **kwargs): - trigger_plugins(self, "after_train_dataset_adaptation", **kwargs) - - def _before_eval_dataset_adaptation(self, **kwargs): - trigger_plugins(self, "before_eval_dataset_adaptation", **kwargs) - - def _after_eval_dataset_adaptation(self, **kwargs): - trigger_plugins(self, "after_eval_dataset_adaptation", **kwargs) diff --git a/avalanche/training/templates/problem_type/__init__.py b/avalanche/training/templates/problem_type/__init__.py new file mode 100644 index 000000000..0932beb4c --- /dev/null +++ b/avalanche/training/templates/problem_type/__init__.py @@ -0,0 +1,5 @@ +"""Problem types mainly define the properties and criterions depending on + how inputs should be mapped to outputs. 
+ +""" +from .supervised_problem import SupervisedProblem diff --git a/avalanche/training/templates/problem_type/supervised_problem.py b/avalanche/training/templates/problem_type/supervised_problem.py new file mode 100644 index 000000000..9432e04ef --- /dev/null +++ b/avalanche/training/templates/problem_type/supervised_problem.py @@ -0,0 +1,31 @@ +from avalanche.models import avalanche_forward + + +class SupervisedProblem: + @property + def mb_x(self): + """Current mini-batch input.""" + return self.mbatch[0] + + @property + def mb_y(self): + """Current mini-batch target.""" + return self.mbatch[1] + + @property + def mb_task_id(self): + """Current mini-batch task labels.""" + assert len(self.mbatch) >= 3 + return self.mbatch[-1] + + def criterion(self): + """Loss function for supervised problems.""" + return self._criterion(self.mb_output, self.mb_y) + + def forward(self): + """Compute the model's output given the current mini-batch.""" + return avalanche_forward(self.model, self.mb_x, self.mb_task_id) + + def _check_minibatch(self): + """Check if the current mini-batch has 3 components.""" + assert len(self.mbatch) >= 3 diff --git a/avalanche/training/templates/supervised.py b/avalanche/training/templates/supervised.py deleted file mode 100644 index d65a60be8..000000000 --- a/avalanche/training/templates/supervised.py +++ /dev/null @@ -1,312 +0,0 @@ -from typing import Sequence, Optional -from pkg_resources import parse_version - -import torch -from torch.nn import Module, CrossEntropyLoss -from torch.optim import Optimizer -from torch.utils.data import DataLoader - -from avalanche.benchmarks.utils.data_loader import ( - TaskBalancedDataLoader, - collate_from_data_or_kwargs, -) -from avalanche.models import avalanche_forward -from avalanche.models.dynamic_optimizers import reset_optimizer -from avalanche.models.utils import avalanche_model_adaptation -from avalanche.training.plugins import SupervisedPlugin -from avalanche.training.plugins.evaluation import default_evaluator -from avalanche.training.templates.base_sgd import BaseSGDTemplate -from avalanche.training.utils import trigger_plugins - - -class SupervisedTemplate(BaseSGDTemplate): - """Base class for continual learning strategies. - - BaseTemplate is the super class of all task-based continual learning - strategies. It implements a basic training loop and callback system - that allows to execute code at each experience of the training loop. - Plugins can be used to implement callbacks to augment the training - loop with additional behavior (e.g. a memory buffer for replay). - - **Scenarios** - This strategy supports several continual learning scenarios: - - * class-incremental scenarios (no task labels) - * multi-task scenarios, where task labels are provided) - * multi-incremental scenarios, where the same task may be revisited - - The exact scenario depends on the data stream and whether it provides - the task labels. 
- - **Training loop** - The training loop is organized as follows:: - - train - train_exp # for each experience - adapt_train_dataset - train_dataset_adaptation - make_train_dataloader - train_epoch # for each epoch - # forward - # backward - # model update - - **Evaluation loop** - The evaluation loop is organized as follows:: - - eval - eval_exp # for each experience - adapt_eval_dataset - eval_dataset_adaptation - make_eval_dataloader - eval_epoch # for each epoch - # forward - # backward - # model update - - """ - - PLUGIN_CLASS = SupervisedPlugin - - def __init__( - self, - model: Module, - optimizer: Optimizer, - criterion=CrossEntropyLoss(), - train_mb_size: int = 1, - train_epochs: int = 1, - eval_mb_size: Optional[int] = 1, - device="cpu", - plugins: Optional[Sequence["SupervisedPlugin"]] = None, - evaluator=default_evaluator(), - eval_every=-1, - peval_mode="epoch", - ): - """Init. - - :param model: PyTorch model. - :param optimizer: PyTorch optimizer. - :param criterion: loss function. - :param train_mb_size: mini-batch size for training. - :param train_epochs: number of training epochs. - :param eval_mb_size: mini-batch size for eval. - :param device: PyTorch device where the model will be allocated. - :param plugins: (optional) list of StrategyPlugins. - :param evaluator: (optional) instance of EvaluationPlugin for logging - and metric computations. None to remove logging. - :param eval_every: the frequency of the calls to `eval` inside the - training loop. -1 disables the evaluation. 0 means `eval` is called - only at the end of the learning experience. Values >0 mean that - `eval` is called every `eval_every` epochs and at the end of the - learning experience. - :param peval_mode: one of {'epoch', 'iteration'}. Decides whether the - periodic evaluation during training should execute every - `eval_every` epochs or iterations (Default='epoch'). - """ - super().__init__( - model=model, - optimizer=optimizer, - train_mb_size=train_mb_size, - train_epochs=train_epochs, - eval_mb_size=eval_mb_size, - device=device, - plugins=plugins, - evaluator=evaluator, - eval_every=eval_every, - peval_mode=peval_mode, - ) - self._criterion = criterion - - ################################################################### - # State variables. These are updated during the train/eval loops. # - ################################################################### - - self.adapted_dataset = None - """ Data used to train. It may be modified by plugins. Plugins can - append data to it (e.g. for replay). - - .. note:: - - This dataset may contain samples from different experiences. If you - want the original data for the current experience - use :attr:`.BaseTemplate.experience`. - """ - - @property - def mb_x(self): - """Current mini-batch input.""" - return self.mbatch[0] - - @property - def mb_y(self): - """Current mini-batch target.""" - return self.mbatch[1] - - @property - def mb_task_id(self): - """Current mini-batch task labels.""" - assert len(self.mbatch) >= 3 - return self.mbatch[-1] - - def criterion(self): - """Loss function.""" - return self._criterion(self.mb_output, self.mb_y) - - def _before_training_exp(self, **kwargs): - """Setup to train on a single experience.""" - # Data Adaptation (e.g. 
add new samples/data augmentation) - self._before_train_dataset_adaptation(**kwargs) - self.train_dataset_adaptation(**kwargs) - self._after_train_dataset_adaptation(**kwargs) - super()._before_training_exp(**kwargs) - - def _load_train_state(self, prev_state): - super()._load_train_state(prev_state) - self.adapted_dataset = prev_state["adapted_dataset"] - self.dataloader = prev_state["dataloader"] - - def _save_train_state(self): - """Save the training state which may be modified by the eval loop. - - This currently includes: experience, adapted_dataset, dataloader, - is_training, and train/eval modes for each module. - - TODO: we probably need a better way to do this. - """ - state = super()._save_train_state() - new_state = { - "adapted_dataset": self.adapted_dataset, - "dataloader": self.dataloader, - } - return {**state, **new_state} - - def train_dataset_adaptation(self, **kwargs): - """Initialize `self.adapted_dataset`.""" - self.adapted_dataset = self.experience.dataset - self.adapted_dataset = self.adapted_dataset.train() - - def _before_eval_exp(self, **kwargs): - # Data Adaptation - self._before_eval_dataset_adaptation(**kwargs) - self.eval_dataset_adaptation(**kwargs) - self._after_eval_dataset_adaptation(**kwargs) - super()._before_eval_exp(**kwargs) - - def make_train_dataloader( - self, - num_workers=0, - shuffle=True, - pin_memory=True, - persistent_workers=False, - **kwargs - ): - """Data loader initialization. - - Called at the start of each learning experience after the dataset - adaptation. - - :param num_workers: number of thread workers for the data loading. - :param shuffle: True if the data should be shuffled, False otherwise. - :param pin_memory: If True, the data loader will copy Tensors into CUDA - pinned memory before returning them. Defaults to True. - """ - - other_dataloader_args = {} - - if parse_version(torch.__version__) >= parse_version("1.7.0"): - other_dataloader_args["persistent_workers"] = persistent_workers - for k, v in kwargs.items(): - other_dataloader_args[k] = v - - self.dataloader = TaskBalancedDataLoader( - self.adapted_dataset, - oversample_small_groups=True, - num_workers=num_workers, - batch_size=self.train_mb_size, - shuffle=shuffle, - pin_memory=pin_memory, - **other_dataloader_args - ) - - def make_eval_dataloader( - self, num_workers=0, pin_memory=True, persistent_workers=False, **kwargs - ): - """ - Initializes the eval data loader. - :param num_workers: How many subprocesses to use for data loading. - 0 means that the data will be loaded in the main process. - (default: 0). - :param pin_memory: If True, the data loader will copy Tensors into CUDA - pinned memory before returning them. Defaults to True. - :param kwargs: - :return: - """ - other_dataloader_args = {} - - if parse_version(torch.__version__) >= parse_version("1.7.0"): - other_dataloader_args["persistent_workers"] = persistent_workers - for k, v in kwargs.items(): - other_dataloader_args[k] = v - - collate_from_data_or_kwargs(self.adapted_dataset, other_dataloader_args) - self.dataloader = DataLoader( - self.adapted_dataset, - num_workers=num_workers, - batch_size=self.eval_mb_size, - pin_memory=pin_memory, - **other_dataloader_args - ) - - def forward(self): - """Compute the model's output given the current mini-batch.""" - return avalanche_forward(self.model, self.mb_x, self.mb_task_id) - - def model_adaptation(self, model=None): - """Adapts the model to the current data. - - Calls the :class:`~avalanche.models.DynamicModule`s adaptation. 
- """ - if model is None: - model = self.model - avalanche_model_adaptation(model, self.experience) - return model.to(self.device) - - def _unpack_minibatch(self): - """We assume mini-batches have the form . - This allows for arbitrary tensors between y and t. - Keep in mind that in the most general case mb_task_id is a tensor - which may contain different labels for each sample. - """ - assert len(self.mbatch) >= 3 - super()._unpack_minibatch() - - def eval_dataset_adaptation(self, **kwargs): - """Initialize `self.adapted_dataset`.""" - self.adapted_dataset = self.experience.dataset - self.adapted_dataset = self.adapted_dataset.eval() - - def make_optimizer(self): - """Optimizer initialization. - - Called before each training experiene to configure the optimizer. - """ - # we reset the optimizer's state after each experience. - # This allows to add new parameters (new heads) and - # freezing old units during the model's adaptation phase. - reset_optimizer(self.optimizer, self.model) - - ######################################################### - # Plugin Triggers # - ######################################################### - - def _before_train_dataset_adaptation(self, **kwargs): - trigger_plugins(self, "before_train_dataset_adaptation", **kwargs) - - def _after_train_dataset_adaptation(self, **kwargs): - trigger_plugins(self, "after_train_dataset_adaptation", **kwargs) - - def _before_eval_dataset_adaptation(self, **kwargs): - trigger_plugins(self, "before_eval_dataset_adaptation", **kwargs) - - def _after_eval_dataset_adaptation(self, **kwargs): - trigger_plugins(self, "after_eval_dataset_adaptation", **kwargs) diff --git a/avalanche/training/templates/update_type/__init__.py b/avalanche/training/templates/update_type/__init__.py new file mode 100644 index 000000000..3cc498524 --- /dev/null +++ b/avalanche/training/templates/update_type/__init__.py @@ -0,0 +1,5 @@ +"""Update types define how the model is updated for every batch of data. +""" + +from .sgd_update import SGDUpdate +from .meta_update import MetaUpdate diff --git a/avalanche/training/templates/update_type/meta_update.py b/avalanche/training/templates/update_type/meta_update.py new file mode 100644 index 000000000..d387db9c0 --- /dev/null +++ b/avalanche/training/templates/update_type/meta_update.py @@ -0,0 +1,51 @@ +from avalanche.training.utils import trigger_plugins + + +class MetaUpdate: + def training_epoch(self, **kwargs): + """Training epoch. 
+ + :param kwargs: + :return: + """ + for self.mbatch in self.dataloader: + if self._stop_training: + break + + self._unpack_minibatch() + self._before_training_iteration(**kwargs) + + self.optimizer.zero_grad() + self.loss = 0 + + # Inner updates + self._before_inner_updates(**kwargs) + self._inner_updates(**kwargs) + self._after_inner_updates(**kwargs) + + # Outer update + self._before_outer_update(**kwargs) + self._outer_update(**kwargs) + self._after_outer_update(**kwargs) + + self.mb_output = self.forward() + + self._after_training_iteration(**kwargs) + + def _before_inner_updates(self, **kwargs): + trigger_plugins(self, "before_inner_updates", **kwargs) + + def _inner_updates(self, **kwargs): + raise NotImplementedError() + + def _after_inner_updates(self, **kwargs): + trigger_plugins(self, "after_inner_updates", **kwargs) + + def _before_outer_update(self, **kwargs): + trigger_plugins(self, "before_outer_update", **kwargs) + + def _outer_update(self, **kwargs): + raise NotImplementedError() + + def _after_outer_update(self, **kwargs): + trigger_plugins(self, "after_outer_update", **kwargs) diff --git a/avalanche/training/templates/update_type/sgd_update.py b/avalanche/training/templates/update_type/sgd_update.py new file mode 100644 index 000000000..d85365f49 --- /dev/null +++ b/avalanche/training/templates/update_type/sgd_update.py @@ -0,0 +1,36 @@ + +class SGDUpdate: + def training_epoch(self, **kwargs): + """Training epoch. + + :param kwargs: + :return: + """ + for self.mbatch in self.dataloader: + if self._stop_training: + break + + self._unpack_minibatch() + self._before_training_iteration(**kwargs) + + self.optimizer.zero_grad() + self.loss = 0 + + # Forward + self._before_forward(**kwargs) + self.mb_output = self.forward() + self._after_forward(**kwargs) + + # Loss & Backward + self.loss += self.criterion() + + self._before_backward(**kwargs) + self.backward() + self._after_backward(**kwargs) + + # Optimization step + self._before_update(**kwargs) + self.optimizer_step() + self._after_update(**kwargs) + + self._after_training_iteration(**kwargs) diff --git a/examples/naive.py b/examples/naive.py new file mode 100644 index 000000000..91e895dac --- /dev/null +++ b/examples/naive.py @@ -0,0 +1,59 @@ +import torch +from os.path import expanduser + +from avalanche.models import SimpleMLP +from avalanche.evaluation.metrics import ( + accuracy_metrics, + loss_metrics, +) +from avalanche.training.plugins import EvaluationPlugin +from avalanche.benchmarks.classic import SplitMNIST +from avalanche.logging import InteractiveLogger +from avalanche.training.supervised import ( + Naive +) + + +def main(): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + scenario = SplitMNIST( + n_experiences=5, + dataset_root=expanduser("~") + "/.avalanche/data/mnist/" + ) + + # choose some metrics and evaluation method + interactive_logger = InteractiveLogger() + eval_plugin = EvaluationPlugin( + accuracy_metrics( + minibatch=True, epoch=True, experience=True, stream=True + ), + loss_metrics(minibatch=True, epoch=True, experience=True, stream=True), + loggers=[interactive_logger], + ) + + model = SimpleMLP(hidden_size=128) + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + criterion = torch.nn.CrossEntropyLoss() + + # create strategy + strategy = Naive( + model, + optimizer, + criterion, + train_epochs=1, + device=device, + train_mb_size=32, + evaluator=eval_plugin, + ) + + # train on the selected scenario with the chosen strategy + 
for experience in scenario.train_stream: + print("Start training on experience ", experience.current_experience) + strategy.train(experience) + strategy.eval(scenario.test_stream[:]) + + +if __name__ == "__main__": + main() diff --git a/examples/online_naive.py b/examples/online_naive.py index c9c714e37..d4b7b581f 100644 --- a/examples/online_naive.py +++ b/examples/online_naive.py @@ -116,14 +116,14 @@ def main(args): # ocl_benchmark = OnlineCLScenario(batch_streams) for i, exp in enumerate(scenario.train_stream): # Create online scenario from experience exp - ocl_benchmark = OnlineCLScenario( - original_streams=batch_streams, - experiences=exp, - experience_size=1, - access_task_boundaries=True, - ) + ocl_benchmark = OnlineCLScenario(original_streams=batch_streams, + experiences=exp, + experience_size=1, + access_task_boundaries=True) + # Train on the online train stream of the scenario cl_strategy.train(ocl_benchmark.train_stream) + results.append(cl_strategy.eval(scenario.original_test_stream)) diff --git a/examples/online_replay.py b/examples/online_replay.py index 8ee61de7e..cf6750dd9 100644 --- a/examples/online_replay.py +++ b/examples/online_replay.py @@ -128,7 +128,7 @@ def main(args): original_streams=batch_streams, experiences=exp, experience_size=1 ) # Train on the online train stream of the scenario - cl_strategy.train(ocl_benchmark.online_train_stream) + cl_strategy.train(ocl_benchmark.train_stream) results.append(cl_strategy.eval(scenario.test_stream)) diff --git a/tests/training/test_online_strategies.py b/tests/training/test_online_strategies.py index b02819810..fdced8935 100644 --- a/tests/training/test_online_strategies.py +++ b/tests/training/test_online_strategies.py @@ -10,6 +10,7 @@ from avalanche.benchmarks.scenarios.online_scenario import OnlineCLScenario from avalanche.training import OnlineNaive from tests.unit_tests_utils import get_fast_benchmark +from avalanche.training.plugins.evaluation import default_evaluator class StrategyTest(unittest.TestCase): @@ -52,10 +53,10 @@ def test_naive(self): train_mb_size=1, device=self.device, eval_mb_size=50, + evaluator=default_evaluator(), ) - ocl_benchmark = OnlineCLScenario( - benchmark_streams, access_task_boundaries=True - ) + ocl_benchmark = OnlineCLScenario(benchmark_streams, + access_task_boundaries=True) self.run_strategy_boundaries(ocl_benchmark, strategy) # Without task boundaries @@ -67,10 +68,10 @@ def test_naive(self): train_mb_size=1, device=self.device, eval_mb_size=50, + evaluator=default_evaluator(), ) - ocl_benchmark = OnlineCLScenario( - benchmark_streams, access_task_boundaries=False - ) + ocl_benchmark = OnlineCLScenario(benchmark_streams, + access_task_boundaries=False) self.run_strategy_no_boundaries(ocl_benchmark, strategy) def load_benchmark(self, use_task_labels=False): diff --git a/tests/training/test_replay.py b/tests/training/test_replay.py index 4088b4136..ad7dbf359 100644 --- a/tests/training/test_replay.py +++ b/tests/training/test_replay.py @@ -22,7 +22,7 @@ ParametricBuffer, ) from avalanche.training.supervised import Naive -from avalanche.training.templates.supervised import SupervisedTemplate +from avalanche.training.templates import SupervisedTemplate from tests.unit_tests_utils import get_fast_benchmark diff --git a/tests/training/test_strategies.py b/tests/training/test_strategies.py index 70cd193f3..328658973 100644 --- a/tests/training/test_strategies.py +++ b/tests/training/test_strategies.py @@ -47,7 +47,7 @@ from avalanche.training.supervised.icarl import ICaRL from 
avalanche.training.supervised.joint_training import AlreadyTrainedError from avalanche.training.supervised.strategy_wrappers import PNNStrategy -from avalanche.training.templates.supervised import SupervisedTemplate +from avalanche.training.templates import SupervisedTemplate from avalanche.training.utils import get_last_fc_layer from tests.unit_tests_utils import get_fast_benchmark, get_device
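
The `_inner_updates` and `_outer_update` methods introduced by `MetaUpdate` are left as `NotImplementedError` and are meant to be filled in by concrete meta-learning strategies built on `SupervisedMetaLearningTemplate`. As a rough sketch (not part of the patch itself), a first-order meta-update could override these hooks roughly as follows; the `FirstOrderMAML` name, the inner learning rate, the number of inner steps, and the assumption that `SupervisedMetaLearningTemplate` is re-exported from `avalanche.training.templates` are illustrative, not defined by this change:

    # Sketch only: fills in the MetaUpdate hooks with a simple first-order
    # meta-update. All hyper-parameters below are illustrative assumptions.
    import copy

    from torch.optim import SGD

    from avalanche.models import avalanche_forward
    from avalanche.training.templates import SupervisedMetaLearningTemplate


    class FirstOrderMAML(SupervisedMetaLearningTemplate):
        """Adapts a copy of the model in `_inner_updates`, then reuses the
        adapted gradients for the slow weights in `_outer_update`."""

        def _inner_updates(self, **kwargs):
            # Take a few SGD steps on a throw-away copy of the model.
            self.fast_model = copy.deepcopy(self.model)
            inner_opt = SGD(self.fast_model.parameters(), lr=0.01)
            for _ in range(3):
                out = avalanche_forward(
                    self.fast_model, self.mb_x, self.mb_task_id)
                inner_loss = self._criterion(out, self.mb_y)
                inner_opt.zero_grad()
                inner_loss.backward()
                inner_opt.step()

        def _outer_update(self, **kwargs):
            # First-order approximation: gradients computed through the
            # adapted copy are applied directly to the slow parameters.
            out = avalanche_forward(
                self.fast_model, self.mb_x, self.mb_task_id)
            self.loss = self._criterion(out, self.mb_y)
            self.loss.backward()
            for slow_p, fast_p in zip(self.model.parameters(),
                                      self.fast_model.parameters()):
                if fast_p.grad is not None:
                    slow_p.grad = fast_p.grad.clone()
            self.optimizer.step()

    # A streaming variant could instead be composed from OnlineObservation,
    # SupervisedProblem, MetaUpdate and BaseSGDTemplate, mirroring the
    # OnlineSupervisedTemplate composition defined in this patch.

Such a subclass would be trained exactly like the templates above (`strategy.train(experience)` / `strategy.eval(stream)`), since `MetaUpdate.training_epoch` already wires the inner/outer hooks and their plugin callbacks into the standard loop.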