Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New mixin-based template structure #1092

Merged
merged 16 commits into from
Oct 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
30 changes: 30 additions & 0 deletions avalanche/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,33 @@ def after_eval_dataset_adaptation(
) -> CallbackResult:
"""Called after `eval_dataset_adaptation` by the `BaseTemplate`."""
pass


class SupervisedMetaLearningPlugin(SupervisedPlugin[Template], ABC):
"""ABC for SupervisedMetaLearningTemplate plugins.

See `BaseTemplate` for complete description of the train/eval loop.
"""
def before_inner_updates(
self, strategy: Template, *args, **kwargs
) -> CallbackResult:
"""Called before `_inner_updates` by the `BaseTemplate`."""
pass

def after_inner_updates(
self, strategy: Template, *args, **kwargs
) -> CallbackResult:
"""Called before `_outer_updates` by the `BaseTemplate`."""
pass

def before_outer_update(
self, strategy: Template, *args, **kwargs
) -> CallbackResult:
"""Called before `_outer_updates` by the `BaseTemplate`."""
pass

def after_outer_update(
self, strategy: Template, *args, **kwargs
) -> CallbackResult:
"""Called before `_outer_updates` by the `BaseTemplate`."""
pass
2 changes: 1 addition & 1 deletion avalanche/evaluation/metric_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

if TYPE_CHECKING:
from .metric_results import MetricResult
from ..training.templates.supervised import SupervisedTemplate
from ..training.templates import SupervisedTemplate

TResult = TypeVar("TResult")
TAggregated = TypeVar("TAggregated", bound="PluginMetric")
Expand Down
2 changes: 1 addition & 1 deletion avalanche/evaluation/metric_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from torch import Tensor

if TYPE_CHECKING:
from avalanche.training.templates.supervised import SupervisedTemplate
from avalanche.training.templates import SupervisedTemplate
from avalanche.benchmarks.scenarios import ClassificationExperience
from avalanche.evaluation import PluginMetric

Expand Down
2 changes: 1 addition & 1 deletion avalanche/evaluation/metrics/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from avalanche.evaluation.metric_utils import get_metric_name

if TYPE_CHECKING:
from avalanche.training.templates.supervised import SupervisedTemplate
from avalanche.training.templates import SupervisedTemplate


class WeightCheckpoint(PluginMetric[Tensor]):
Expand Down
2 changes: 1 addition & 1 deletion avalanche/evaluation/metrics/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
)

if TYPE_CHECKING:
from avalanche.training.templates.supervised import SupervisedTemplate
from avalanche.training.templates import SupervisedTemplate


class ConfusionMatrix(Metric[Tensor]):
Expand Down
2 changes: 1 addition & 1 deletion avalanche/evaluation/metrics/forgetting_bwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
)

if TYPE_CHECKING:
from avalanche.training.templates.supervised import SupervisedTemplate
from avalanche.training.templates import SupervisedTemplate


class Forgetting(Metric[Union[float, None, Dict[int, float]]]):
Expand Down
2 changes: 1 addition & 1 deletion avalanche/evaluation/metrics/forward_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
)

if TYPE_CHECKING:
from avalanche.training.templates.supervised import SupervisedTemplate
from avalanche.training.templates import SupervisedTemplate


class ForwardTransfer(Metric[Union[float, None, Dict[int, float]]]):
Expand Down
2 changes: 1 addition & 1 deletion avalanche/evaluation/metrics/gpu_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from avalanche.evaluation.metric_results import MetricResult

if TYPE_CHECKING:
from avalanche.training.templates.supervised import SupervisedTemplate
from avalanche.training.templates import SupervisedTemplate


class MaxGPU(Metric[float]):
Expand Down
2 changes: 1 addition & 1 deletion avalanche/evaluation/metrics/images_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


if TYPE_CHECKING:
from avalanche.training.templates.supervised import SupervisedTemplate
from avalanche.training.templates import SupervisedTemplate
from avalanche.benchmarks.utils import make_classification_dataset


Expand Down
2 changes: 1 addition & 1 deletion avalanche/evaluation/metrics/labels_repartition.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


if TYPE_CHECKING:
from avalanche.training.templates.supervised import SupervisedTemplate
from avalanche.training.templates import SupervisedTemplate
from avalanche.evaluation.metric_results import MetricResult


Expand Down
2 changes: 1 addition & 1 deletion avalanche/evaluation/metrics/mean_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from typing_extensions import Literal

if TYPE_CHECKING:
from avalanche.training.templates.supervised import SupervisedTemplate
from avalanche.training.templates import SupervisedTemplate
from avalanche.evaluation.metric_results import MetricResult


Expand Down
2 changes: 1 addition & 1 deletion avalanche/evaluation/metrics/ram_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from avalanche.evaluation.metric_results import MetricResult

if TYPE_CHECKING:
from avalanche.training.templates.supervised import SupervisedTemplate
from avalanche.training.templates import SupervisedTemplate


class MaxRAM(Metric[float]):
Expand Down
2 changes: 1 addition & 1 deletion avalanche/evaluation/metrics/timing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from avalanche.evaluation.metrics.mean import Mean

if TYPE_CHECKING:
from avalanche.training.templates.supervised import SupervisedTemplate
from avalanche.training.templates import SupervisedTemplate


class ElapsedTime(Metric[float]):
Expand Down
2 changes: 1 addition & 1 deletion avalanche/logging/base_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

if TYPE_CHECKING:
from avalanche.evaluation.metric_results import MetricValue
from avalanche.training.templates.supervised import SupervisedTemplate
from avalanche.training.templates import SupervisedTemplate


class BaseLogger(ABC):
Expand Down
6 changes: 5 additions & 1 deletion avalanche/logging/interactive_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from tqdm import tqdm

if TYPE_CHECKING:
from avalanche.training.templates.supervised import SupervisedTemplate
from avalanche.training.templates import SupervisedTemplate


class InteractiveLogger(TextLogger, SupervisedPlugin):
Expand Down Expand Up @@ -61,6 +61,8 @@ def before_training_epoch(
metric_values: List["MetricValue"],
**kwargs
):
if isinstance(strategy.experience, OnlineCLExperience):
return
super().before_training_epoch(strategy, metric_values, **kwargs)
self._progress.total = len(strategy.dataloader)

Expand All @@ -70,6 +72,8 @@ def after_training_epoch(
metric_values: List["MetricValue"],
**kwargs
):
if isinstance(strategy.experience, OnlineCLExperience):
return
self._end_progress()
super().after_training_epoch(strategy, metric_values, **kwargs)

Expand Down
2 changes: 1 addition & 1 deletion avalanche/logging/wandb_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

if TYPE_CHECKING:
from avalanche.evaluation.metric_results import MetricValue
from avalanche.training.templates.supervised import SupervisedTemplate
from avalanche.training.templates import SupervisedTemplate


class WandBLogger(BaseLogger, SupervisedPlugin):
Expand Down
2 changes: 1 addition & 1 deletion avalanche/training/plugins/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
if TYPE_CHECKING:
from avalanche.evaluation import PluginMetric
from avalanche.logging import BaseLogger
from avalanche.training.templates.supervised import SupervisedTemplate
from avalanche.training.templates import SupervisedTemplate


class EvaluationPlugin:
Expand Down
2 changes: 1 addition & 1 deletion avalanche/training/plugins/gdumb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from avalanche.training.storage_policy import ClassBalancedBuffer

if TYPE_CHECKING:
from avalanche.training.templates.supervised import SupervisedTemplate
from avalanche.training.templates import SupervisedTemplate


class GDumbPlugin(SupervisedPlugin):
Expand Down
2 changes: 1 addition & 1 deletion avalanche/training/plugins/gss_greedy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from avalanche.training.plugins.strategy_plugin import SupervisedPlugin

if TYPE_CHECKING:
from ..templates.supervised import SupervisedTemplate
from ..templates import SupervisedTemplate


class GSS_greedyPlugin(SupervisedPlugin):
Expand Down
2 changes: 1 addition & 1 deletion avalanche/training/plugins/lr_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import inspect

if TYPE_CHECKING:
from avalanche.training.templates.supervised import SupervisedTemplate
from avalanche.training.templates import SupervisedTemplate


class LRSchedulerPlugin(SupervisedPlugin):
Expand Down
2 changes: 1 addition & 1 deletion avalanche/training/plugins/replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
)

if TYPE_CHECKING:
from avalanche.training.templates.supervised import SupervisedTemplate
from avalanche.training.templates import SupervisedTemplate


class ReplayPlugin(SupervisedPlugin):
Expand Down
2 changes: 1 addition & 1 deletion avalanche/training/plugins/synaptic_intelligence.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from avalanche.training.utils import get_layers_and_params

if TYPE_CHECKING:
from ..templates.supervised import SupervisedTemplate
from ..templates import SupervisedTemplate

SynDataType = Dict[str, Dict[str, Tensor]]

Expand Down
2 changes: 1 addition & 1 deletion avalanche/training/storage_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ..benchmarks.utils.utils import concat_datasets

if TYPE_CHECKING:
from .templates.supervised import SupervisedTemplate
from .templates import SupervisedTemplate


class ExemplarsBuffer(ABC):
Expand Down
2 changes: 1 addition & 1 deletion avalanche/training/supervised/ar1.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
SynapticIntelligencePlugin,
CWRStarPlugin,
)
from avalanche.training.templates.supervised import SupervisedTemplate
from avalanche.training.templates import SupervisedTemplate
from avalanche.training.utils import (
replace_bn_with_brn,
get_last_fc_layer,
Expand Down
2 changes: 1 addition & 1 deletion avalanche/training/supervised/cumulative.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from avalanche.benchmarks.utils.utils import concat_datasets
from avalanche.training.plugins.evaluation import default_evaluator
from avalanche.training.plugins import SupervisedPlugin, EvaluationPlugin
from avalanche.training.templates.supervised import SupervisedTemplate
from avalanche.training.templates import SupervisedTemplate


class Cumulative(SupervisedTemplate):
Expand Down
2 changes: 1 addition & 1 deletion avalanche/training/supervised/deep_slda.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch

from avalanche.training.plugins import SupervisedPlugin
from avalanche.training.templates.supervised import SupervisedTemplate
from avalanche.training.templates import SupervisedTemplate
from avalanche.training.plugins.evaluation import default_evaluator
from avalanche.models.dynamic_modules import MultiTaskModule
from avalanche.models import FeatureExtractorBackbone
Expand Down
2 changes: 1 addition & 1 deletion avalanche/training/supervised/icarl.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from avalanche.training.plugins.strategy_plugin import SupervisedPlugin
from torch.nn import Module
from torch.utils.data import DataLoader
from avalanche.training.templates.supervised import SupervisedTemplate
from avalanche.training.templates import SupervisedTemplate


class ICaRL(SupervisedTemplate):
Expand Down
2 changes: 1 addition & 1 deletion avalanche/training/supervised/joint_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from avalanche.benchmarks.utils import concat_classification_datasets
from avalanche.benchmarks.utils.utils import concat_datasets
from avalanche.training.plugins.evaluation import default_evaluator
from avalanche.training.templates.supervised import SupervisedTemplate
from avalanche.training.templates import SupervisedTemplate
from avalanche.models import DynamicModule

if TYPE_CHECKING:
Expand Down