Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cumulative Strategy and tests implemented #130

Merged
merged 7 commits into from
Oct 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions avalanche/training/strategies/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .cl_naive import *
from .cl_cumulative import *
77 changes: 77 additions & 0 deletions avalanche/training/strategies/cl_cumulative.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

################################################################################
# Copyright (c) 2020 ContinualAI Research #
# Copyrights licensed under the CC BY 4.0 License. #
# See the accompanying LICENSE file for terms. #
# #
# Date: 1-06-2020 #
# Author(s): Andrea Cossu #
# E-mail: contact@continualai.org #
# Website: clair.continualai.org #
################################################################################

from __future__ import absolute_import
from __future__ import division
# Python 2-3 compatible
from __future__ import print_function

from typing import Optional, Sequence

from torch.nn import Module
from torch.optim import Optimizer

from avalanche.evaluation import EvalProtocol
from avalanche.training.skeletons import TrainingFlow
from avalanche.training.strategies import Naive
from avalanche.benchmarks.scenarios.generic_definitions import IStepInfo
from avalanche.training.utils import ConcatDatasetWithTargets
from avalanche.training.skeletons import StrategySkeleton

class Cumulative(Naive):
"""
A Cumulative strategy in which, at each step (or task), the model
is trained with all the data encountered so far. Therefore, at each step,
the model is trained in a MultiTask scenario.
The strategy has a high memory and computational cost.
"""

def __init__(self, model: Module, classifier_field: str,
optimizer: Optimizer, criterion: Module,
train_mb_size: int = 1, train_epochs: int = 1,
test_mb_size: int = None, device=None,
evaluation_protocol: Optional[EvalProtocol] = None,
plugins: Optional[Sequence[StrategySkeleton]] = None):

super(Cumulative, self).__init__(model, classifier_field,
optimizer, criterion, train_mb_size, train_epochs,
test_mb_size, device, evaluation_protocol, plugins)


@TrainingFlow
def make_train_dataset(self, step_info: IStepInfo):
"""
Returns the training dataset, given the step_info instance.
The dataset is composed by all datasets encountered so far.

This is a part of the training flow. Sets the train_dataset namespace
value.

:param step_info: The step info instance, as returned from the CL
scenario.
:return: The training dataset.
"""

train_dataset = step_info.cumulative_training_sets()

train_dataset = ConcatDatasetWithTargets(
[ el[0] for el in train_dataset ]
)

self.update_namespace(train_dataset=train_dataset)
self.update_namespace(step_id=step_info.current_step)
return train_dataset


__all__ = ['Cumulative']
2 changes: 1 addition & 1 deletion avalanche/training/strategies/cl_naive.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing import Optional, Sequence

from torch.nn import Module
from torch.optim.optimizer import Optimizer
from torch.optim import Optimizer

from avalanche.evaluation import EvalProtocol
from avalanche.training.templates.deep_learning_strategy import \
Expand Down
122 changes: 122 additions & 0 deletions tests/test_strategies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

################################################################################
# Copyright (c) 2020 ContinualAI Research #
# Copyrights licensed under the CC BY 4.0 License. #
# See the accompanying LICENSE file for terms. #
# #
# Date: 1-06-2020 #
# Author(s): Andrea Cossu #
# E-mail: contact@continualai.org #
# Website: clair.continualai.org #
################################################################################

import unittest

from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Compose
from torch.optim import SGD
from torch.nn import CrossEntropyLoss

from avalanche.extras.models import SimpleMLP
from avalanche.evaluation import EvalProtocol
from avalanche.evaluation.metrics import ACC
from avalanche.benchmarks.scenarios import \
create_nc_single_dataset_sit_scenario, DatasetPart, NCBatchInfo
from avalanche.training.strategies import Naive, Cumulative
from avalanche.training.plugins import ReplayPlugin

device = 'cpu'

class StrategyTest(unittest.TestCase):

def test_naive(self):
model = SimpleMLP()
optimizer = SGD(model.parameters(), lr=1e-3)
criterion = CrossEntropyLoss()
mnist_train, mnist_test = self.load_dataset()
nc_scenario = create_nc_single_dataset_sit_scenario(
mnist_train, mnist_test, 5, shuffle=True, seed=1234)

eval_protocol = EvalProtocol(
metrics=[
ACC(num_class=nc_scenario.n_classes)
])

strategy = Naive(model, 'classifier', optimizer, criterion,
evaluation_protocol=eval_protocol, train_mb_size=100,
train_epochs=4, test_mb_size=100, device=device)

self.run_strategy(nc_scenario, strategy)


def test_replay(self):
model = SimpleMLP()
optimizer = SGD(model.parameters(), lr=1e-3)
criterion = CrossEntropyLoss()
mnist_train, mnist_test = self.load_dataset()
nc_scenario = create_nc_single_dataset_sit_scenario(
mnist_train, mnist_test, 5, shuffle=True, seed=1234)

eval_protocol = EvalProtocol(
metrics=[
ACC(num_class=nc_scenario.n_classes)
])

strategy = Naive(model, 'classifier', optimizer, criterion,
evaluation_protocol=eval_protocol,
train_mb_size=100,
train_epochs=4, test_mb_size=100, device=device,
plugins=[ReplayPlugin(mem_size=10)])

self.run_strategy(nc_scenario, strategy)


def test_cumulative(self):
model = SimpleMLP()
optimizer = SGD(model.parameters(), lr=1e-3)
criterion = CrossEntropyLoss()
mnist_train, mnist_test = self.load_dataset()
nc_scenario = create_nc_single_dataset_sit_scenario(
mnist_train, mnist_test, 5, shuffle=True, seed=1234)

eval_protocol = EvalProtocol(
metrics=[
ACC(num_class=nc_scenario.n_classes)
])

strategy = Cumulative(model, 'classifier', optimizer, criterion,
train_mb_size=100,
evaluation_protocol=eval_protocol,
train_epochs=4, test_mb_size=100, device=device)

self.run_strategy(nc_scenario, strategy)


def load_dataset(self):

mnist_train = MNIST('./data/mnist', train=True, download=True,
transform=Compose([ToTensor()]))
mnist_test = MNIST('./data/mnist', train=False, download=True,
transform=Compose([ToTensor()]))
return mnist_train, mnist_test

def run_strategy(self, scenario, cl_strategy):

print('Starting experiment...')
results = []
batch_info: NCBatchInfo
for batch_info in scenario:
print("Start of step ", batch_info.current_step)

cl_strategy.train(batch_info, num_workers=4)
print('Training completed')

print('Computing accuracy on the whole test set')
results.append(cl_strategy.test(batch_info, DatasetPart.COMPLETE,
num_workers=4))


if __name__ == '__main__':
unittest.main()