In [None]:
import random
from typing import Optional, Sequence, Union, List

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.nn import CrossEntropyLoss
from torch.nn import CrossEntropyLoss,Module
from torch.optim import Optimizer, SGD
from torch.optim import Adam

from avalanche.benchmarks import nc_benchmark
from avalanche.benchmarks.utils import AvalancheDataset, AvalancheTensorDataset
from avalanche.evaluation.metrics import forgetting_metrics, accuracy_metrics, loss_metrics
from avalanche.logging import InteractiveLogger
from avalanche.models import SimpleMLP
from avalanche.models import avalanche_forward
from avalanche.training.plugins import AGEMPlugin
from avalanche.training.plugins import EvaluationPlugin
from avalanche.training.plugins.evaluation import default_logger
from avalanche.training.plugins.strategy_plugin import StrategyPlugin
from avalanche.training.strategies import BaseStrategy

from online_continuum import OnlineContinuum
from models import *

In [None]:
class cbrsPLUGIN(AGEMPlugin):
    """A-GEM plugin whose episodic memory is managed by Class-Balancing
    Reservoir Sampling (CBRS, Chrysakis & Moens, ICML 2020).

    A-GEM's per-experience memory is replaced by a flat list of
    ``[item, label, task]`` entries. The list is updated one stream instance
    at a time through :meth:`sample` (aliased as :meth:`update_mem`).
    Before every training iteration, reference gradients are computed on a
    random sample of that list. The inherited A-GEM machinery then uses them
    for gradient projection.
    """

    def __init__(self, patterns_per_experience: int, sample_size: int,
                 mem_size: int = 500):
        """
        :param patterns_per_experience: forwarded to ``AGEMPlugin``; the CBRS
            memory itself does not use it.
        :param sample_size: forwarded to ``AGEMPlugin``.
        :param mem_size: maximum number of instances kept in the CBRS buffer.
        """
        super().__init__(patterns_per_experience, sample_size)

        # Flat list of [item, label, task]; (re)initialised explicitly because
        # the CBRS logic below relies on it being a plain list.
        self.buffers = []
        # Classes that are, or at some earlier step were, the largest class in
        # the (full) memory. Once full, a class stays full.
        self.full_class = set()
        # Number of stream instances seen so far per class (n_c).
        self.class_cnt = {}
        self.buf_mem_max_size = mem_size

    @torch.no_grad()
    def update_memory(self, dataset):
        """Disabled: the CBRS memory is updated per instance via ``sample``."""
        pass

    def _memory_class_counts(self):
        """Return ``{label: number of buffered instances with that label}``."""
        counts = {}
        for _, lbl, _ in self.buffers:
            counts[lbl] = counts.get(lbl, 0) + 1
        return counts

    def _indices_of(self, label):
        """Return the buffer positions whose stored label equals ``label``."""
        return [i for i, inst in enumerate(self.buffers) if inst[1] == label]

    def sample(self, item, label, task, nc=None):
        """Offer one stream instance to the memory (one CBRS update step).

        :param item: input of the instance; tensors are stored as a detached
            copy.
        :param label: integer class label.
        :param task: integer task label.
        :param nc: number of stream instances of ``label`` encountered so far,
            including this one. If None, the plugin's own counter is used.
        """
        self.class_cnt[label] = self.class_cnt.get(label, 0) + 1
        if nc is None:
            nc = self.class_cnt[label]

        # ``item`` is usually a view into the current minibatch; storing the
        # view would keep the whole batch tensor alive.
        if torch.is_tensor(item):
            item = item.detach().clone()
        instance = [item, label, task]

        # While the memory is not filled, store every instance.
        if len(self.buffers) < self.buf_mem_max_size:
            self.buffers.append(instance)
            return

        # "Largest" and "full" are defined on memory contents, not stream counts.
        mem_counts = self._memory_class_counts()
        largest_size = max(mem_counts.values())
        largest_classes = [c for c, k in mem_counts.items() if k == largest_size]
        self.full_class.update(largest_classes)

        if label not in self.full_class:
            # Overwrite a random instance of a largest class (ties broken at random).
            victim_cls = random.choice(largest_classes)
            idx = random.choice(self._indices_of(victim_cls))
            self.buffers[idx] = instance
        else:
            # Reservoir sampling restricted to the current class:
            # replace with probability m_c / n_c, otherwise ignore the instance.
            mc = mem_counts.get(label, 0)
            if mc > 0 and random.random() < mc / nc:
                idx = random.choice(self._indices_of(label))
                self.buffers[idx] = instance

    def update_mem(self, item, label, task, nc=None):
        """Alias of :meth:`sample` (the name used by the ``CBRS`` strategy)."""
        self.sample(item, label, task, nc)

    def before_training_iteration(self, strategy, **kwargs):
        """Compute A-GEM reference gradients on a random batch from the buffer.

        The result is stored in ``self.reference_gradients`` and the
        parameter gradients are zeroed again afterwards.
        """
        if len(self.buffers) == 0:
            return

        strategy.model.train()
        strategy.optimizer.zero_grad()

        # Sample without replacement. When the buffer holds fewer entries than
        # a training minibatch, use all of them.
        n_samples = min(strategy.train_mb_size, len(self.buffers))
        replay = random.sample(self.buffers, n_samples)
        xs, ys, ts = zip(*replay)

        xref = torch.stack(xs).to(strategy.device)
        yref = torch.tensor(ys, dtype=torch.int64, device=strategy.device)
        tref = torch.tensor(ts, dtype=torch.int64, device=strategy.device)

        # Predict the replayed labels and back-propagate their loss.
        out = avalanche_forward(strategy.model, xref, tref)
        replay_loss = strategy._criterion(out, yref)
        replay_loss.backward()

        self.reference_gradients = torch.cat([
            p.grad.view(-1) if p.grad is not None
            else torch.zeros(p.numel(), device=strategy.device)
            for p in strategy.model.parameters()])
        strategy.optimizer.zero_grad()

In [None]:
class CBRS(BaseStrategy):
    """A-GEM training whose replay memory is managed by a ``cbrsPLUGIN``.

    The plugin is created here and registered with the strategy. After every
    training iteration, each instance of the minibatch is offered to the
    plugin's memory together with the number of stream instances of its class
    seen so far.
    """

    def __init__(self, model: Module, optimizer: Optimizer, criterion,
                 patterns_per_exp: int, sample_size: int = 64,
                 train_mb_size: int = 1, train_epochs: int = 1,
                 eval_mb_size: int = None, device=None,
                 plugins: Optional[List[StrategyPlugin]] = None,
                 evaluator: EvaluationPlugin = default_logger, eval_every=-1):
        cbrs = cbrsPLUGIN(patterns_per_exp, sample_size)
        # The plugin must be part of the list handed to BaseStrategy, otherwise
        # it is never registered. Copy the caller's list instead of mutating it.
        plugins = [cbrs] if plugins is None else list(plugins) + [cbrs]

        super().__init__(
            model, optimizer, criterion,
            train_mb_size=train_mb_size, train_epochs=train_epochs,
            eval_mb_size=eval_mb_size, device=device, plugins=plugins,
            evaluator=evaluator, eval_every=eval_every)

        # Distinct class labels seen in training minibatches so far.
        self.nC = set()
        # Not read anywhere in this class; kept for external users.
        self.nb = 1
        # Number of stream instances seen so far, per class label (n_c).
        self.stream_instances_encountered = {}

    def training_epoch(self, **kwargs):
        """One pass over ``self.dataloader``.

        Each iteration runs the usual Avalanche callbacks, with the loss scaled
        by 1 / (number of distinct classes seen so far). It then offers every
        minibatch instance to the CBRS memory.
        """
        for self.mbatch in self.dataloader:
            if self._stop_training:
                break

            self._unpack_minibatch()

            self.nC.update(self.mbatch[1].tolist())
            a = 1 / len(self.nC)

            self._before_training_iteration(**kwargs)

            self.optimizer.zero_grad()
            self.loss = 0

            self._before_forward(**kwargs)
            self.mb_output = self.forward()
            self._after_forward(**kwargs)

            self.loss += a * self.criterion()

            self._before_backward(**kwargs)
            self.loss.backward()
            self._after_backward(**kwargs)

            self._before_update(**kwargs)
            self.optimizer.step()
            self._after_update(**kwargs)

            self._after_training_iteration(**kwargs)

            self._update_cbrs_memory()

    def _update_cbrs_memory(self):
        """Count each minibatch instance per class and offer it to every CBRS plugin."""
        cbrs_plugins = [p for p in self.plugins if isinstance(p, cbrsPLUGIN)]
        items, labels, tasks = self.mbatch[0], self.mbatch[1], self.mbatch[2]
        for i in range(len(labels)):
            label = int(labels[i])
            self.stream_instances_encountered[label] = \
                self.stream_instances_encountered.get(label, 0) + 1
            item = items[i]
            task = int(tasks[i])
            for p in cbrs_plugins:
                p.sample(item, label, task,
                         self.stream_instances_encountered[label])

In [None]:
# Reproducibility: seed every RNG used later in the notebook
# (the CBRS memory uses `random`; model initialisation uses torch).
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

continuum = OnlineContinuum('MNIST', transform=None)

max_imbalance = 2
IMBALANCE_STEP = 0.5
# Number of imbalance levels at IMBALANCE_STEP spacing up to max_imbalance
# (assumed meaning of create_imbalances' second argument — TODO confirm).
steps = round(max_imbalance / IMBALANCE_STEP) + 1
continuum.create_imbalances(max_imbalance, steps)

train_dataset = AvalancheTensorDataset(continuum.inputs, continuum.labels)
test_dataset = AvalancheTensorDataset(continuum.inputs1, continuum.labels1)

# 10 experiences over the 10 listed classes -> one class per experience.
N_EXPERIENCES = 10
CLASS_ORDER = list(range(10))
scenario = nc_benchmark(
    train_dataset, test_dataset,
    n_experiences=N_EXPERIENCES, seed=SEED, task_labels=True,
    fixed_class_order=CLASS_ORDER,
)

train_stream = scenario.train_stream

for experience in train_stream:
    t = experience.task_label
    exp_id = experience.current_experience
    training_dataset = experience.dataset

    print('Task {} batch {} -> train'.format(t, exp_id))
    print('This batch contains', len(training_dataset), 'patterns')

In [None]:
# MODEL CREATION
model = SimpleMLP(num_classes=scenario.n_classes)

# Metrics and logging
interactive_logger = InteractiveLogger()

eval_plugin = EvaluationPlugin(
    accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
    loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
    forgetting_metrics(experience=True),
    loggers=[interactive_logger],
)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

TRAIN_MB_SIZE = 128  # b - batch size
EVAL_MB_SIZE = 128
# NOTE(review): CBRS.training_epoch offers every instance to the memory and
# increments the per-class stream counter (n_c) on *every* epoch. With
# TRAIN_EPOCHS > 1, n_c therefore counts revisits. Use 1 for a single-pass
# online setting.
TRAIN_EPOCHS = 10
# Forwarded to AGEMPlugin; the CBRS buffer size is set inside cbrsPLUGIN.
PATTERNS_PER_EXP = 1000

# CREATE THE STRATEGY INSTANCE (A-GEM with a CBRS-managed replay memory)
cl_strategy = CBRS(
    model,
    optimizer=Adam(model.parameters()),
    patterns_per_exp=PATTERNS_PER_EXP,
    criterion=CrossEntropyLoss(),
    train_mb_size=TRAIN_MB_SIZE,
    train_epochs=TRAIN_EPOCHS,
    eval_mb_size=EVAL_MB_SIZE,
    evaluator=eval_plugin,
    device=device,
)

# TRAINING LOOP: train on each experience, then evaluate on the full test stream.
print("Starting experiment...")
results = []
for experience in scenario.train_stream:
    print("Start of experience ", experience.current_experience)
    cl_strategy.train(experience)
    print("Training completed")

    print("Computing accuracy on the whole test set")
    results.append(cl_strategy.eval(scenario.test_stream))