# Code


In [None]:
# Imports

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision
import numpy as np
from torch.utils.data import Subset, TensorDataset, DataLoader
from random import shuffle
from tqdm import tqdm

# Run on the first CUDA GPU when one is visible, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

device

## Utilities

In [None]:
# data

def get_random_split_array(num_classes: int, num_tasks: int, classes_per_task: int = 2) -> dict:
  """
  Randomly assign disjoint groups of class indices to tasks.

  The indices ``0 .. num_classes - 1`` are shuffled with ``random.shuffle``
  and consecutive chunks of ``classes_per_task`` indices are handed to the
  tasks in order. Classes left over after the last task are not used.

  Args:
    num_classes: total number of classes available.
    num_tasks: number of tasks to create.
    classes_per_task: number of classes given to every task (default 2).

  Returns:
    A dict mapping the task id as a string ("0", "1", ...) to the list of
    class indices that belong to that task.

  Raises:
    ValueError: if an argument is not positive, or if there are fewer than
      ``num_tasks * classes_per_task`` classes.
  """
  if num_classes <= 0:
    raise ValueError("Number of classes must be greater than zero")
  if num_tasks <= 0:
    raise ValueError("Number of tasks must be greater than zero")
  if classes_per_task <= 0:
    raise ValueError("Number of classes per task must be greater than zero")
  # Without this check the slicing below would silently hand out short or
  # empty class lists to the last tasks.
  if num_tasks * classes_per_task > num_classes:
    raise ValueError(
        f"Cannot build {num_tasks} tasks with {classes_per_task} classes each "
        f"from only {num_classes} classes"
    )

  classes = list(range(num_classes))
  shuffle(classes)

  return {
      str(task): classes[task * classes_per_task:(task + 1) * classes_per_task]
      for task in range(num_tasks)
  }

def get_split_dataset(dataset : torch.utils.data.Dataset,
                      split: dict,
                      task_incremental: bool = False) -> dict:
  """
  Split a dataset into one Subset per task.

  Args:
    dataset: a torch dataset exposing a ``targets`` attribute with one
      class label per sample.
    split: mapping from task id to the collection of class labels that
      belong to that task, e.g. ``{"0": [3, 5], "1": [4, 6]}`` puts
      classes 3 and 5 in task "0" and classes 4 and 6 in task "1".
    task_incremental: label remapping per task; not implemented.

  Returns:
    A dict with the same keys as ``split`` whose values are
    ``torch.utils.data.Subset`` views of ``dataset`` restricted to the
    samples whose target is one of the task's classes.

  Raises:
    NotImplementedError: if ``task_incremental`` is requested.
  """
  if task_incremental:
    raise NotImplementedError("This method does not support remapping yet")

  # Convert the labels once; they do not change between tasks.
  targets = np.asarray(dataset.targets)

  split_dataset = {}
  for task, current_classes in split.items():
    in_task = np.isin(targets, current_classes)
    split_dataset[task] = Subset(dataset, np.flatnonzero(in_task))
  return split_dataset


def get_dataloaders(split_dataset : dict(), batch_size: int, shuffle: bool = True) -> dict():
  """
  Wrap every per-task dataset of ``split_dataset`` in a DataLoader.

  The returned dictionary has the same keys (tasks) as the input; each value
  is a DataLoader over the matching dataset, created with the requested
  ``batch_size`` and ``shuffle`` flag.
  """
  return {
      task_id: torch.utils.data.DataLoader(task_data, batch_size=batch_size, shuffle=shuffle)
      for task_id, task_data in split_dataset.items()
  }

# greedy buffer

class GreedyBuffer:
    """
    Replay buffer that greedily keeps the first samples seen of each class.

    Every call to ``store_data`` reads a whole loader and, for each distinct
    label found in it, appends the first ``samples_per_class`` samples with
    that label (in loader order) to the buffer.
    """

    def __init__(self, samples_per_class):
        self.samples_per_class = samples_per_class
        self.samples = torch.Tensor([])
        # Labels are kept as int64 so they never take a detour through float.
        self.targets = torch.empty(0, dtype=torch.int64)

    @staticmethod
    def _extend(stored, new):
        # An empty buffer is simply replaced, which avoids concatenating a
        # 1-D empty tensor with N-D data of another dtype or device.
        if stored.numel() == 0:
            return new
        return torch.cat((stored, new))

    def store_data(self, loader):
        """Add up to ``samples_per_class`` samples per label found in ``loader``."""
        sample_batches, target_batches = [], []
        for sample, target in loader:
            sample_batches.append(sample)
            target_batches.append(target)

        if not sample_batches:
            return

        # Concatenate once instead of re-copying the growing tensor per batch.
        samples = torch.cat(sample_batches)
        targets = torch.cat(target_batches)

        # torch.unique returns sorted labels, so the buffer is grouped by class.
        greedy_idx = torch.cat([
            torch.where(targets == label)[0][:self.samples_per_class]
            for label in torch.unique(targets)
        ])
        self.samples = self._extend(self.samples, samples[greedy_idx])
        self.targets = self._extend(self.targets, targets[greedy_idx].to(torch.int64))

    def get_data(self):
        """Return the stored ``(samples, targets)`` pair; targets are int64."""
        return self.samples, self.targets.to(torch.int64)

    def __len__(self):
        assert len(self.samples) == len(self.targets), f"Incosistent lengths of data tensor: {self.samples.shape}, target tensor: {self.targets.shape}!"
        return len(self.samples)

In [None]:
# metrics

def compute_backward_transfer(array : np.ndarray):
    """
    Compute the backward transfer (BWT) from a square accuracy matrix.

    The matrix is read as ``array[j, i]``: the accuracy on task ``j``
    (row) measured after training stage ``i`` (column). The last column
    therefore holds the final accuracies. BWT is the mean, over every task
    except the last one, of the final accuracy minus the accuracy obtained
    right after that task was trained (the diagonal).

    Returns ``nan`` when fewer than two tasks are present, because no
    backward transfer can be measured.
    """
    array = np.asarray(array, dtype=float)
    num_tasks = array.shape[0]
    if num_tasks < 2:
        # Explicit result instead of a 0/0 division with a RuntimeWarning.
        return float("nan")

    diag = np.diag(array)[:-1]  # the last task has nothing trained after it
    end_acc = array[:-1, -1]
    return np.sum(end_acc - diag) / (num_tasks - 1)


def compute_forward_transfer(array, b):
    """
    Compute the forward transfer (FWT) from a square accuracy matrix.

    The matrix is read as ``array[j, i]``: the accuracy on task ``j``
    (row) measured after training stage ``i`` (column). The first
    sub-diagonal therefore holds the accuracy on each task just before it is
    trained. FWT is the mean, over every task except the first one, of that
    accuracy minus the baseline accuracy ``b[j]`` on the same task.

    Args:
        array: square accuracy matrix.
        b: one baseline accuracy per task, in task order.

    Returns ``nan`` when fewer than two tasks are present.

    Raises:
        ValueError: if ``b`` does not hold exactly one value per task.
    """
    array = np.asarray(array, dtype=float)
    num_tasks = array.shape[0]
    if num_tasks < 2:
        # Explicit result instead of a 0/0 division with a RuntimeWarning.
        return float("nan")

    baseline = np.asarray(b, dtype=float).reshape(-1)
    if baseline.shape[0] != num_tasks:
        # A wrong-length baseline could otherwise broadcast silently.
        raise ValueError(
            f"Expected {num_tasks} baseline accuracies, got {baseline.shape[0]}"
        )

    sub_diag = np.diag(array, k=-1)  # the first task has no earlier stage
    return np.sum(sub_diag - baseline[1:]) / (num_tasks - 1)

def compute_average_accuracy(array):
    """
    Average the entries of the last column of the accuracy matrix,
    dividing their sum by the number of rows.
    """
    final_column = array[:, -1]
    return np.sum(final_column, axis=0) / len(array)

In [None]:
# plotting

def dict2array(acc, device):
    """
    Turn an accuracy dictionary into a 2D numpy array.

    Each key is converted with ``int`` and used as the row index; the value
    stored under it fills that row. The number of columns is the length of
    the first entry when it is a list, otherwise the number of keys. When
    ``device`` is not "cpu" every element is first moved to the CPU and
    converted to numpy.
    """
    n_rows = len(acc)
    head = acc[list(acc.keys())[0]]
    if isinstance(head, list):
        n_cols = len(head)
    else:
        n_cols = n_rows

    table = np.zeros((n_rows, n_cols))
    needs_transfer = device != "cpu"
    for key in acc:
        row = acc[key]
        if needs_transfer:
            row = [entry.cpu().numpy() for entry in row]
        table[int(key), :] = row
    return table


def plot_accuracy_matrix(array):
    """
    Show the accuracy matrix as an annotated heat map.

    Every cell is labelled with its value rounded to two decimals. The title
    reports the mean and the standard deviation of the last column, computed
    from the unrounded values.
    """
    array = np.asarray(array)
    num_rows, num_cols = array.shape
    # Keep the unrounded final column for the title statistics.
    final_acc = array[:, -1]
    rounded = np.round(array, 2)

    fig, ax = plt.subplots()
    ax.imshow(rounded, vmin=np.min(rounded), vmax=np.max(rounded))
    for i in range(num_rows):
        for j in range(num_cols):
            ax.text(j, i, rounded[i, j], va='center', ha='center', c='w', fontsize=15)
    ax.set_yticks(np.arange(num_rows))
    ax.set_ylabel('Number of tasks')
    ax.set_xticks(np.arange(num_cols))
    ax.set_xlabel('Tasks finished')
    # The spread is taken over the per-task accuracies; the standard
    # deviation of their mean (a single number) would always be zero.
    ax.set_title(f"ACC: {np.mean(final_acc):.3f} -- std {np.std(final_acc):.3f}")
    plt.show()


def plot_acc_over_time(array):
    """
    Draw one line per row of ``array`` on a single set of axes, labelled
    with the row index, then show the figure with a legend.
    """
    _, axes = plt.subplots()
    for row_index, curve in enumerate(array):
        axes.plot(curve, label=row_index)
    plt.legend()
    plt.show()

## Data and split


In [None]:
# Normalisation statistics, one entry per image channel (MNIST has one).
# The trailing commas make these real one-element tuples.
mean, std = (0.1307,), (0.3081,)

transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=mean, std=std)
        ])

# Downloads MNIST into the current directory on first use.
train_dataset = torchvision.datasets.MNIST(root=".", train=True, download=True, transform=transforms)
val_dataset = torchvision.datasets.MNIST(root=".", train=False, download=True, transform=transforms)

# define number of classes and tasks
num_classes = len(train_dataset.classes)
#@markdown Please ensure that number of tasks is lower (or equal) than the number of classes
num_tasks = 5 # @param{type:"slider", min:1, max:10, step:1}
train_bs = 128 # @param{type:"integer", min:2, max:128}
val_bs = 128 # @param{type:"integer", min:2, max:128}

# One random class split is drawn and shared by training and validation.
train_split = get_random_split_array(num_classes=num_classes, num_tasks=num_tasks)
train_dataset_split = get_split_dataset(train_dataset, train_split)
train_loaders = get_dataloaders(train_dataset_split, batch_size=train_bs, shuffle=True)

val_dataset_split = get_split_dataset(val_dataset, train_split)
val_loaders = get_dataloaders(val_dataset_split, batch_size=val_bs, shuffle=False)

print(f"Using the following splits: \n training: {train_split} \n validation: {train_split}")


## Models


In [None]:
# MLP
class MLP(torch.nn.Module):
    """
    Fully connected classifier: three hidden ReLU layers of equal width
    followed by a linear output layer.

    ``args`` provides 'in_size' (side of the square input), 'n_channels',
    'hidden_size' and 'num_classes'.
    """

    def __init__(self, args):
        super().__init__()
        width = args['hidden_size']
        flat_features = args['in_size']**2 * args['n_channels']
        self.fc1 = nn.Linear(flat_features, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, width)
        self.fc4 = nn.Linear(width, args['num_classes'])

    def forward(self, input):
        # Everything after the batch dimension is flattened into one vector.
        hidden = input.flatten(start_dim=1)
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))
        return self.fc4(hidden)


# Conditional VAE (supports very basic conditioning on class labels)
class Encoder(nn.Module):
    """
    Conditional encoder: the input is concatenated with the label tensor,
    passed through two ReLU layers and mapped to a latent mean and a latent
    log-variance by two separate linear heads.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, num_classes):
        super().__init__()
        wide_dim = hidden_dim*2
        self.fc1 = nn.Linear(input_dim + num_classes, hidden_dim)
        self.fc1_bis = nn.Linear(hidden_dim, wide_dim)
        self.fc2_mean = nn.Linear(wide_dim, latent_dim)
        self.fc2_logvar = nn.Linear(wide_dim, latent_dim)

    def forward(self, x, labels):
        conditioned = torch.cat((x, labels), dim=1)
        features = F.relu(self.fc1_bis(F.relu(self.fc1(conditioned))))
        return self.fc2_mean(features), self.fc2_logvar(features)

class Decoder(nn.Module):
    """
    Conditional decoder: the latent code is concatenated with the label
    tensor, passed through two ReLU layers and projected to the output
    size, with a sigmoid applied to the result.
    """

    def __init__(self, latent_dim, hidden_dim, output_dim, num_classes):
        super().__init__()
        wide_dim = hidden_dim*2
        self.fc1 = nn.Linear(latent_dim + num_classes, hidden_dim)
        self.fc1_bis = nn.Linear(hidden_dim, wide_dim)
        self.fc2 = nn.Linear(wide_dim, output_dim)

    def forward(self, z, labels):
        conditioned = torch.cat((z, labels), dim=1)
        features = F.relu(self.fc1(conditioned))
        features = F.relu(self.fc1_bis(features))
        return torch.sigmoid(self.fc2(features))

class CVAE(nn.Module):
    """
    Conditional variational autoencoder built from an Encoder and a Decoder.

    ``args`` provides "input_dim", "hidden_dim", "latent_dim" and
    "num_classes". The labels passed to ``forward`` are concatenated with
    the data, so they must be 2D tensors (e.g. one-hot rows).
    """

    def __init__(self, args):
        super(CVAE, self).__init__()

        input_dim = args["input_dim"]
        hidden_dim = args["hidden_dim"]
        latent_dim = args["latent_dim"]
        num_classes = args["num_classes"]

        self.encoder = Encoder(input_dim, hidden_dim, latent_dim, num_classes)
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim, num_classes)
        self.num_classes = num_classes

    def reparameterize(self, mean, logvar):
        """Draw ``mean + eps * std`` with ``eps`` standard normal noise."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + eps * std

    def forward(self, x, labels):
        """Return ``(reconstruction, mean, logvar)`` for a labelled batch."""
        mean, logvar = self.encoder(x, labels)
        z = self.reparameterize(mean, logvar)
        x_reconstructed = self.decoder(z, labels)
        return x_reconstructed, mean, logvar

    def sample(self, num_samples, latent_dim, class_label):
        """
        Decode ``num_samples`` random latent codes conditioned on one class.

        The tensors are created on the device that holds the decoder's
        parameters, so sampling also works on CPU-only machines.
        """
        target_device = next(self.decoder.parameters()).device
        # Noise is still drawn on the CPU generator and then moved.
        z = torch.randn(num_samples, latent_dim).to(target_device)
        labels = torch.zeros(num_samples, self.num_classes, device=target_device)
        labels[:, class_label] = 1  # one-hot row for the requested class
        samples = self.decoder(z, labels)
        return samples

def vae_loss(x_reconstructed, x, mean, logvar):
    """
    VAE objective: summed squared reconstruction error plus the KL
    divergence term computed from ``mean`` and ``logvar``.
    """
    # Summed (not averaged) squared error between reconstruction and input
    squared_error = F.mse_loss(x_reconstructed, x, reduction="sum")

    # -0.5 * sum(1 + logvar - mean^2 - exp(logvar))
    kl_terms = 1 + logvar - mean.pow(2) - logvar.exp()
    kl_penalty = -0.5 * torch.sum(kl_terms)

    return squared_error + kl_penalty

In [None]:
from copy import deepcopy

# agents
class Agent:
  """
  Super class defining a CL agent that
  implements the basic utilities. Each subclass
  must implement its own train and validation.

  Args:
    criterion: loss used to train the classifier.
    tasks: task identifiers; they are the keys of the accuracy dictionaries.
    training_args: needs "use_buffer" (samples per class for the greedy
      buffer, 0 disables it) and "generative" (samples per class drawn from
      the CVAE, 0 disables generative replay).
    model_args: needs an "MLP" entry and, for generative replay, a "CVAE"
      entry with "input_dim", "latent_dim", "num_classes" and
      "training_epochs".
    class_split: optional mapping from task id to the class labels of that
      task. When omitted, ``train_VAE`` falls back to the notebook-level
      ``train_split``.
  """
  def __init__(self,
               criterion: torch.nn.modules.loss._Loss,
               tasks: list,
               training_args: dict,
               model_args: dict,
               class_split: dict = None):
    self.criterion = criterion
    self.tasks = tasks
    self.training_args = training_args
    self.model_args = model_args
    self.class_split = class_split
    self.model = MLP(model_args["MLP"]).to(device)
    self.optimizer = None

    if self.training_args["use_buffer"] > 0:
      self.buffer = GreedyBuffer(samples_per_class=self.training_args["use_buffer"])
    else:
      self.buffer = None

    if self.training_args["generative"] > 0:
      self.generative_replay = True
      self.ConditionalVAE = CVAE(self.model_args["CVAE"]).to(device)
    else:
      self.generative_replay = False
      self.ConditionalVAE = None

    # Note that tasks should be a list of integers for ease of use
    self.acc_dict = {key: [] for key in tasks} #dictionary storing the accuracy of the model on all tasks on current iteration
    self.acc_end_dict = {key: [] for key in tasks} #dictionary storing the accuracy of the model on all tasks so far, measured at the end

  def reset_accuracy(self):
    """Clear both accuracy histories, keeping one empty list per task."""
    self.acc_dict = {key: [] for key in self.tasks}
    self.acc_end_dict = {key: [] for key in self.tasks}

  def train_VAE(self, loader, current_task):
    """
    Starts or continues to train the CVAE to act as
    a replay buffer. Returns a buffer with samples
    generated from the classes seen so far.

    Args:
      loader: DataLoader over the current task's data.
      current_task: id of the current task; ``int(current_task)`` must be
        its position in the task sequence.

    Returns:
      ``(samples, labels)``: generated images shaped (N, 1, 28, 28) and
      their int64 class labels, with ``training_args["generative"]``
      samples for every class of tasks ``0 .. current_task``.
    """
    cvae_args = self.model_args["CVAE"]
    num_epochs = cvae_args["training_epochs"][current_task]

    #training
    #NOTE: a fresh optimizer is created on every call, so its state is not
    #carried over between tasks
    optimizer = torch.optim.AdamW(self.ConditionalVAE.parameters(), lr=1e-3)

    self.ConditionalVAE.train()
    for epoch in range(num_epochs):
      total_loss = 0
      for data, labels in loader:
        data = data.view(-1, cvae_args["input_dim"]).to(device)
        # The one-hot width comes from the CVAE config so it always matches
        # the layer sizes of the encoder and decoder.
        labels = torch.nn.functional.one_hot(labels, cvae_args["num_classes"]).float().to(device)
        optimizer.zero_grad()
        x_reconstructed, mean, logvar = self.ConditionalVAE(data, labels)
        loss = vae_loss(x_reconstructed, data, mean, logvar)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

      if epoch % 5 == 0:
        avg_loss = total_loss / len(loader.dataset)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

    #sampling
    num_samples = self.training_args["generative"]
    latent_dim = cvae_args["latent_dim"]

    # Every class of every task seen so far is replayed, however many
    # classes a task holds.
    split = self.class_split if self.class_split is not None else train_split
    available_labels = [
        label
        for task in range(int(current_task) + 1)
        for label in split[str(task)]
    ]

    sample_chunks, label_chunks = [], []
    with torch.no_grad():
      for label in available_labels:
        # generate samples for each class seen so far
        samples = self.ConditionalVAE.sample(num_samples, latent_dim, label)

        # NOTE: the reshape assumes single-channel 28x28 images, i.e. an
        # input_dim of 784.
        samples = samples.view(num_samples, 1, 28, 28)
        sample_chunks.append(samples)
        label_chunks.append(
            torch.full((num_samples,), int(label), dtype=torch.int64, device=samples.device)
        )

    # Concatenate once at the end instead of growing tensors inside the loop.
    return torch.cat(sample_chunks, dim=0), torch.cat(label_chunks, dim=0)

class CIA_Agent(Agent):
  """
  Agent that, for every task, re-initialises the MLP and trains it on the
  task loader, on the greedy buffer, or on CVAE-generated samples, and then
  evaluates it on all validation loaders.
  """
  def __init__(self, criterion, tasks, training_args, model_args):
    super().__init__(criterion, tasks, training_args, model_args)
    #add additional CI init steps here

  def _shared_step(self, loader : torch.utils.data.DataLoader, val_loaders: dict = None, train: bool = True):
    """
    Run one pass over ``loader``.

    When ``train`` is true the model is optimised on every batch and a
    mid-epoch validation runs every 50 batches (starting with the first).

    Returns:
      ``(mean loss per batch, number of samples seen, number of correct
      predictions)``. The last value is a tensor once at least one batch
      was processed.
    """
    epoch_loss, total, correct = 0, 0, 0
    num_batches = 0

    for batch_idx, (X, y) in enumerate(loader):
      X, y = X.to(device), y.to(device)
      output = self.model(X)
      loss = self.criterion(output, y)

      if train:
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

      num_batches += 1
      epoch_loss += loss.item()
      correct += torch.sum(output.argmax(dim=1) == y)
      total += len(X)

      # mid epoch validation
      # not executed when the validation shared step is called
      if train and val_loaders is not None and batch_idx % 50 == 0:
        self.validate(val_loaders, end_of_epoch = False)

    # Average over the number of batches actually processed. Dividing by the
    # last batch index would be off by one and fail for a single batch.
    mean_loss = epoch_loss / num_batches if num_batches else 0.0
    return mean_loss, total, correct

  def train(self, train_loaders : dict, val_loaders : dict):
    """
    Train on every task of ``train_loaders`` in order.

    For each task the MLP and its optimizer are re-created. The training
    data is then the greedy buffer (if enabled), otherwise the samples
    generated by the CVAE (if generative replay is enabled), otherwise the
    task loader itself. All validation loaders are evaluated after each
    task and the results are stored in ``acc_end_dict``.
    """
    for task, loader in train_loaders.items():
      print(f"Currently on task {task} \n")
      print("Resetting model and optimizer \n")
      self.model = MLP(self.model_args["MLP"]).to(device)
      self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)

      if self.buffer is not None:
        self.buffer.store_data(loader)
        print(f"Buffer stores {len(self.buffer)} samples")
        samples, targets = self.buffer.get_data()
        #replace the task loader with one over the greedily accumulated samples
        loader = DataLoader(TensorDataset(samples, targets), batch_size=loader.batch_size, shuffle=True)
      elif self.generative_replay:
        #first train VAE
        print("Training VAE \n")
        gen_dataset, gen_labels = self.train_VAE(loader, task)
        #replace the task loader with one over the samples generated by the CVAE
        loader = DataLoader(TensorDataset(gen_dataset, gen_labels), batch_size=loader.batch_size, shuffle=True)
        print(f"Using VAE samples, new size {gen_dataset.shape[0]}, labels {torch.unique(gen_labels)}")

      print("Training MLP on samples")
      for epoch in range(self.training_args['epochs'][task]):
        epoch_loss, total, correct = self._shared_step(loader, val_loaders, train=True)
        print(f"Epoch {epoch}: Loss {epoch_loss:.3f} Acc: {correct/total:.3f}")

      print(f"Evaluating after task {task} \n")
      self.validate(val_loaders, end_of_epoch=True)

  @torch.no_grad()
  def validate(self, val_loaders, end_of_epoch=False):
      """
      Evaluate the model on every loader of ``val_loaders``.

      The accuracy of each task is appended to ``acc_dict`` and, when
      ``end_of_epoch`` is true, to ``acc_end_dict`` as well. The model is put
      in eval mode for the evaluation and back in train mode afterwards.
      """
      self.model.eval()

      for task, loader in val_loaders.items():
        _, total, correct = self._shared_step(loader, None, train=False)
        accuracy = correct/total
        self.acc_dict[task].append(accuracy)

        if end_of_epoch:
          self.acc_end_dict[task].append(accuracy)

      self.model.train()

## Training


In [None]:
# TODO: convert to omegadict
# The per-task epoch tables are generated from num_tasks so that every task
# key ("0" .. str(num_tasks - 1)) has an entry, whatever the slider is set to.
# train args
training_args = {
    "epochs": {str(task): 10 for task in range(num_tasks)},
    "use_buffer": 0, #if set to anything more than 0, will use a greedy buffer for training
    "generative": 300, #if set to anything more than 0, will use a VAE generated buffer for training
}

# model args
model_args = {
    "MLP" : {
      'in_size': 28,
      'n_channels': 1,
      'hidden_size': 50,
      'num_classes' : num_classes
    },
    "CVAE" : {
        "input_dim" : 28 * 28,
        "hidden_dim" : 2048, #1024
        "latent_dim" : 256, # 128
        # same class count as the classifier, so the one-hot labels fit
        "num_classes" : num_classes,
        "training_epochs": {str(task): 30 for task in range(num_tasks)},
    }
}


In [None]:
# Loss used to train the classifier
criterion = torch.nn.CrossEntropyLoss()
# Task ids are the keys of the class split ("0", "1", ...)
tasks = list(train_split.keys())

# Create the agent & initialize the network
agent = CIA_Agent(criterion=criterion, tasks=tasks, training_args=training_args, model_args=model_args)

# Check & save (for the FWT metric) the accuracy of randomly initialized model
agent.validate(val_loaders)
# validate() appended exactly one accuracy per task; keep that first entry
random_model_acc = [i[0] for i in agent.acc_dict.values()]
# Drop the baseline measurements so they do not end up in the training history
agent.reset_accuracy()

# Train the agent on the whole sequence of tasks
agent.train(train_loaders, val_loaders)

In [None]:
# Get accuracy of the agent at the end of each task
acc_at_end_arr = dict2array(agent.acc_end_dict, device=device)
plot_accuracy_matrix(acc_at_end_arr)

# Get intermediate accuracy
acc_arr = dict2array(agent.acc_dict, device=device)
plot_acc_over_time(acc_arr)

# move random model acc to cpu
if random_model_acc[0].device != "cpu":
  random_model_acc = [x.cpu().numpy() for x in random_model_acc]

print(f"The average accuracy at the end of sequence is: {compute_average_accuracy(acc_at_end_arr):.3f}")
print(f"BWT:'{compute_backward_transfer(acc_at_end_arr):.3f}'")
print(f"FWT:'{compute_forward_transfer(acc_at_end_arr, random_model_acc):.3f}'")