Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New mixin-based template structure #1092

Merged
merged 16 commits into from
Oct 20, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 226 additions & 0 deletions avalanche/NEW_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
from abc import ABC
from typing import TypeVar, Generic
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from avalanche.training.templates.base import BaseTemplate

CallbackResult = TypeVar("CallbackResult")
Template = TypeVar("Template", bound="BaseTemplate")


class BasePlugin(Generic[Template], ABC):
"""ABC for BaseTemplate plugins.

A plugin is simply an object implementing some strategy callbacks.
Plugins are called automatically during the strategy execution.

Callbacks provide access before/after each phase of the execution.
In general, for each method of the training and evaluation loops,
`StrategyCallbacks`
provide two functions `before_{method}` and `after_{method}`, called
before and after the method, respectively.
Therefore plugins can "inject" additional code by implementing callbacks.
Each callback has a `strategy` argument that gives access to the state.

In Avalanche, callbacks are used to implement continual strategies, metrics
and loggers.
"""

def __init__(self):
pass

def before_training(self, strategy: Template, *args, **kwargs):
"""Called before `train` by the `BaseTemplate`."""
pass

def before_training_exp(self, strategy: Template, *args, **kwargs):
"""Called before `train_exp` by the `BaseTemplate`."""
pass

def after_training_exp(self, strategy: Template, *args, **kwargs):
"""Called after `train_exp` by the `BaseTemplate`."""
pass

def after_training(self, strategy: Template, *args, **kwargs):
"""Called after `train` by the `BaseTemplate`."""
pass

def before_eval(
self, strategy: Template, *args, **kwargs
) -> CallbackResult:
"""Called before `eval` by the `BaseTemplate`."""
pass

def before_eval_exp(
self, strategy: Template, *args, **kwargs
) -> CallbackResult:
"""Called before `eval_exp` by the `BaseTemplate`."""
pass

def after_eval_exp(
self, strategy: Template, *args, **kwargs
) -> CallbackResult:
"""Called after `eval_exp` by the `BaseTemplate`."""
pass

def after_eval(self, strategy: Template, *args, **kwargs) -> CallbackResult:
"""Called after `eval` by the `BaseTemplate`."""
pass


class BaseSGDPlugin(BasePlugin[Template], ABC):
"""ABC for BaseSGDTemplate plugins.

See `BaseSGDTemplate` for complete description of the train/eval loop.
"""

def before_training_epoch(
self, strategy: Template, *args, **kwargs
) -> CallbackResult:
"""Called before `train_epoch` by the `BaseTemplate`."""
pass

def before_training_iteration(
self, strategy: Template, *args, **kwargs
) -> CallbackResult:
"""Called before the start of a training iteration by the
`BaseTemplate`."""
pass

def before_forward(
self, strategy: Template, *args, **kwargs
) -> CallbackResult:
"""Called before `model.forward()` by the `BaseTemplate`."""
pass

def after_forward(
self, strategy: Template, *args, **kwargs
) -> CallbackResult:
"""Called after `model.forward()` by the `BaseTemplate`."""
pass

def before_backward(
self, strategy: Template, *args, **kwargs
) -> CallbackResult:
"""Called before `criterion.backward()` by the `BaseTemplate`."""
pass

def after_backward(
self, strategy: Template, *args, **kwargs
) -> CallbackResult:
"""Called after `criterion.backward()` by the `BaseTemplate`."""
pass

def after_training_iteration(
self, strategy: Template, *args, **kwargs
) -> CallbackResult:
"""Called after the end of a training iteration by the
`BaseTemplate`."""
pass

def before_update(
self, strategy: Template, *args, **kwargs
) -> CallbackResult:
"""Called before `optimizer.update()` by the `BaseTemplate`."""
pass

def after_update(
self, strategy: Template, *args, **kwargs
) -> CallbackResult:
"""Called after `optimizer.update()` by the `BaseTemplate`."""
pass

def after_training_epoch(
self, strategy: Template, *args, **kwargs
) -> CallbackResult:
"""Called after `train_epoch` by the `BaseTemplate`."""
pass

def before_eval_iteration(
self, strategy: Template, *args, **kwargs
) -> CallbackResult:
"""Called before the start of a training iteration by the
`BaseTemplate`."""
pass

def before_eval_forward(
self, strategy: Template, *args, **kwargs
) -> CallbackResult:
"""Called before `model.forward()` by the `BaseTemplate`."""
pass

def after_eval_forward(
self, strategy: Template, *args, **kwargs
) -> CallbackResult:
"""Called after `model.forward()` by the `BaseTemplate`."""
pass

def after_eval_iteration(
self, strategy: Template, *args, **kwargs
) -> CallbackResult:
"""Called after the end of an iteration by the
`BaseTemplate`."""
pass

def before_train_dataset_adaptation(
self, strategy: Template, *args, **kwargs
) -> CallbackResult:
"""Called before `train_dataset_adapatation` by the `BaseTemplate`."""
pass

def after_train_dataset_adaptation(
self, strategy: Template, *args, **kwargs
) -> CallbackResult:
"""Called after `train_dataset_adapatation` by the `BaseTemplate`."""
pass

def before_eval_dataset_adaptation(
self, strategy: Template, *args, **kwargs
) -> CallbackResult:
"""Called before `eval_dataset_adaptation` by the `BaseTemplate`."""
pass

def after_eval_dataset_adaptation(
self, strategy: Template, *args, **kwargs
) -> CallbackResult:
"""Called after `eval_dataset_adaptation` by the `BaseTemplate`."""
pass

# ====================================================================> NEW

def before_inner_updates(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we have inner/outer updates here? SGD doesn't have an inner/outer loop

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is what I discussed previously. The new base SGD template is supposed to be the base class for all possible sorts of SGD-based strategies, and the meta-learning template is one of them (as defined in common_tamplates.py). So the idea is to add all callbacks that can be triggered by any of the SGD-based templates. Or are you suggesting splitting it into a separate plugin class for meta-learning-based strategies?

self, strategy: Template, *args, **kwargs
) -> CallbackResult:
"""Called before `_inner_updates` by the `BaseTemplate`."""
pass

def inner_updates(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is not a before/after method, shouldn't be here.

Copy link
Collaborator Author

@HamedHemati HamedHemati Aug 11, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's true, but the reason I added the inner_updates callback as a plugin trigger is that, unlike supervised strategies, we don't have any "Naive" type of updates that I can set as default for meta-learning-based strategies.
More precisely, in supervised strategies, we have the training_epoch function that is implemented in its most basic form (which is Naive fine-tuning), and you can augment it by adding new plugins. We don't have such a general structure similar totraining_epoch for inner updates in meta-learning, and it can be completely different from method to method. That's why I added it as a plugin trigger that has to be implemented by the user. Do you have other suggestions?

self, strategy: Template, *args, **kwargs
) -> CallbackResult:
"""Called before `_inner_updates` by the `BaseTemplate`."""
pass

def after_inner_updates(
self, strategy: Template, *args, **kwargs
) -> CallbackResult:
"""Called before `_outer_updates` by the `BaseTemplate`."""
pass

def before_outer_update(
self, strategy: Template, *args, **kwargs
) -> CallbackResult:
"""Called before `_outer_updates` by the `BaseTemplate`."""
pass

def outer_update(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as inner_update. Shouldn't be here

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same answer as the previous one.

self, strategy: Template, *args, **kwargs
) -> CallbackResult:
"""Called before `_inner_updates` by the `BaseTemplate`."""
pass

def after_outer_update(
self, strategy: Template, *args, **kwargs
) -> CallbackResult:
"""Called before `_outer_updates` by the `BaseTemplate`."""
pass
Loading