-
-
Notifications
You must be signed in to change notification settings - Fork 287
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
New mixin-based template structure #1092
Changes from 5 commits
136c69e
38e0cef
346017b
28d8156
6c5bfb0
23f7af1
02b9267
3e0c54a
50f2683
a66a6a3
5efeefc
1cd1924
aee1d2e
678a234
f00fd4d
6e6359f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,226 @@ | ||
from abc import ABC | ||
from typing import TypeVar, Generic | ||
from typing import TYPE_CHECKING | ||
|
||
if TYPE_CHECKING: | ||
from avalanche.training.templates.base import BaseTemplate | ||
|
||
CallbackResult = TypeVar("CallbackResult") | ||
Template = TypeVar("Template", bound="BaseTemplate") | ||
|
||
|
||
class BasePlugin(Generic[Template], ABC): | ||
"""ABC for BaseTemplate plugins. | ||
|
||
A plugin is simply an object implementing some strategy callbacks. | ||
Plugins are called automatically during the strategy execution. | ||
|
||
Callbacks provide access before/after each phase of the execution. | ||
In general, for each method of the training and evaluation loops, | ||
`StrategyCallbacks` | ||
provide two functions `before_{method}` and `after_{method}`, called | ||
before and after the method, respectively. | ||
Therefore plugins can "inject" additional code by implementing callbacks. | ||
Each callback has a `strategy` argument that gives access to the state. | ||
|
||
In Avalanche, callbacks are used to implement continual strategies, metrics | ||
and loggers. | ||
""" | ||
|
||
def __init__(self): | ||
pass | ||
|
||
def before_training(self, strategy: Template, *args, **kwargs): | ||
"""Called before `train` by the `BaseTemplate`.""" | ||
pass | ||
|
||
def before_training_exp(self, strategy: Template, *args, **kwargs): | ||
"""Called before `train_exp` by the `BaseTemplate`.""" | ||
pass | ||
|
||
def after_training_exp(self, strategy: Template, *args, **kwargs): | ||
"""Called after `train_exp` by the `BaseTemplate`.""" | ||
pass | ||
|
||
def after_training(self, strategy: Template, *args, **kwargs): | ||
"""Called after `train` by the `BaseTemplate`.""" | ||
pass | ||
|
||
def before_eval( | ||
self, strategy: Template, *args, **kwargs | ||
) -> CallbackResult: | ||
"""Called before `eval` by the `BaseTemplate`.""" | ||
pass | ||
|
||
def before_eval_exp( | ||
self, strategy: Template, *args, **kwargs | ||
) -> CallbackResult: | ||
"""Called before `eval_exp` by the `BaseTemplate`.""" | ||
pass | ||
|
||
def after_eval_exp( | ||
self, strategy: Template, *args, **kwargs | ||
) -> CallbackResult: | ||
"""Called after `eval_exp` by the `BaseTemplate`.""" | ||
pass | ||
|
||
def after_eval(self, strategy: Template, *args, **kwargs) -> CallbackResult: | ||
"""Called after `eval` by the `BaseTemplate`.""" | ||
pass | ||
|
||
|
||
class BaseSGDPlugin(BasePlugin[Template], ABC): | ||
"""ABC for BaseSGDTemplate plugins. | ||
|
||
See `BaseSGDTemplate` for complete description of the train/eval loop. | ||
""" | ||
|
||
def before_training_epoch( | ||
self, strategy: Template, *args, **kwargs | ||
) -> CallbackResult: | ||
"""Called before `train_epoch` by the `BaseTemplate`.""" | ||
pass | ||
|
||
def before_training_iteration( | ||
self, strategy: Template, *args, **kwargs | ||
) -> CallbackResult: | ||
"""Called before the start of a training iteration by the | ||
`BaseTemplate`.""" | ||
pass | ||
|
||
def before_forward( | ||
self, strategy: Template, *args, **kwargs | ||
) -> CallbackResult: | ||
"""Called before `model.forward()` by the `BaseTemplate`.""" | ||
pass | ||
|
||
def after_forward( | ||
self, strategy: Template, *args, **kwargs | ||
) -> CallbackResult: | ||
"""Called after `model.forward()` by the `BaseTemplate`.""" | ||
pass | ||
|
||
def before_backward( | ||
self, strategy: Template, *args, **kwargs | ||
) -> CallbackResult: | ||
"""Called before `criterion.backward()` by the `BaseTemplate`.""" | ||
pass | ||
|
||
def after_backward( | ||
self, strategy: Template, *args, **kwargs | ||
) -> CallbackResult: | ||
"""Called after `criterion.backward()` by the `BaseTemplate`.""" | ||
pass | ||
|
||
def after_training_iteration( | ||
self, strategy: Template, *args, **kwargs | ||
) -> CallbackResult: | ||
"""Called after the end of a training iteration by the | ||
`BaseTemplate`.""" | ||
pass | ||
|
||
def before_update( | ||
self, strategy: Template, *args, **kwargs | ||
) -> CallbackResult: | ||
"""Called before `optimizer.update()` by the `BaseTemplate`.""" | ||
pass | ||
|
||
def after_update( | ||
self, strategy: Template, *args, **kwargs | ||
) -> CallbackResult: | ||
"""Called after `optimizer.update()` by the `BaseTemplate`.""" | ||
pass | ||
|
||
def after_training_epoch( | ||
self, strategy: Template, *args, **kwargs | ||
) -> CallbackResult: | ||
"""Called after `train_epoch` by the `BaseTemplate`.""" | ||
pass | ||
|
||
def before_eval_iteration( | ||
self, strategy: Template, *args, **kwargs | ||
) -> CallbackResult: | ||
"""Called before the start of a training iteration by the | ||
`BaseTemplate`.""" | ||
pass | ||
|
||
def before_eval_forward( | ||
self, strategy: Template, *args, **kwargs | ||
) -> CallbackResult: | ||
"""Called before `model.forward()` by the `BaseTemplate`.""" | ||
pass | ||
|
||
def after_eval_forward( | ||
self, strategy: Template, *args, **kwargs | ||
) -> CallbackResult: | ||
"""Called after `model.forward()` by the `BaseTemplate`.""" | ||
pass | ||
|
||
def after_eval_iteration( | ||
self, strategy: Template, *args, **kwargs | ||
) -> CallbackResult: | ||
"""Called after the end of an iteration by the | ||
`BaseTemplate`.""" | ||
pass | ||
|
||
def before_train_dataset_adaptation( | ||
self, strategy: Template, *args, **kwargs | ||
) -> CallbackResult: | ||
"""Called before `train_dataset_adapatation` by the `BaseTemplate`.""" | ||
pass | ||
|
||
def after_train_dataset_adaptation( | ||
self, strategy: Template, *args, **kwargs | ||
) -> CallbackResult: | ||
"""Called after `train_dataset_adapatation` by the `BaseTemplate`.""" | ||
pass | ||
|
||
def before_eval_dataset_adaptation( | ||
self, strategy: Template, *args, **kwargs | ||
) -> CallbackResult: | ||
"""Called before `eval_dataset_adaptation` by the `BaseTemplate`.""" | ||
pass | ||
|
||
def after_eval_dataset_adaptation( | ||
self, strategy: Template, *args, **kwargs | ||
) -> CallbackResult: | ||
"""Called after `eval_dataset_adaptation` by the `BaseTemplate`.""" | ||
pass | ||
|
||
# ====================================================================> NEW | ||
|
||
def before_inner_updates( | ||
self, strategy: Template, *args, **kwargs | ||
) -> CallbackResult: | ||
"""Called before `_inner_updates` by the `BaseTemplate`.""" | ||
pass | ||
|
||
def inner_updates( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is not a before/after method, shouldn't be here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's true, but the reason I added the |
||
self, strategy: Template, *args, **kwargs | ||
) -> CallbackResult: | ||
"""Called before `_inner_updates` by the `BaseTemplate`.""" | ||
pass | ||
|
||
def after_inner_updates( | ||
self, strategy: Template, *args, **kwargs | ||
) -> CallbackResult: | ||
"""Called before `_outer_updates` by the `BaseTemplate`.""" | ||
pass | ||
|
||
def before_outer_update( | ||
self, strategy: Template, *args, **kwargs | ||
) -> CallbackResult: | ||
"""Called before `_outer_updates` by the `BaseTemplate`.""" | ||
pass | ||
|
||
def outer_update( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same as inner_update. Shouldn't be here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same answer as the previous one. |
||
self, strategy: Template, *args, **kwargs | ||
) -> CallbackResult: | ||
"""Called before `_inner_updates` by the `BaseTemplate`.""" | ||
pass | ||
|
||
def after_outer_update( | ||
self, strategy: Template, *args, **kwargs | ||
) -> CallbackResult: | ||
"""Called before `_outer_updates` by the `BaseTemplate`.""" | ||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do we have inner/outer updates here? SGD doesn't have an inner/outer loop
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is what I discussed previously. The new base SGD template is supposed to be the base class for all possible sorts of SGD-based strategies, and the meta-learning template is one of them (as defined in
common_tamplates.py
). So the idea is to add all callbacks that can be triggered by any of the SGD-based templates. Or are you suggesting splitting it into a separate plugin class for meta-learning-based strategies?