## MNIST Task Incremental Learning

In [8]:
import torch.nn as nn
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
from PIL import Image
from typing import Any, Tuple

from typing import List
from copy import deepcopy

### Base CNN architecture

In [49]:
# define the base CNN
class SmallConv(nn.Module):
    """
    Small convolution network with no residual connections.

    Three 3x3 convolutions (80 channels each) form a trunk shared by every
    task; ``self.fc`` holds one linear classifier ("head") per task and each
    row of a batch is sent through the head selected by its task index.
    """
    def __init__(self, num_task=1, num_cls=10, channels=3,
                 avg_pool=2, lin_size=320):
        """
        Params:
          - num_task: number of task-specific linear heads to create
          - num_cls:  number of outputs of every head
          - channels: number of channels of the input images
          - avg_pool: kernel size handed to ``nn.MaxPool2d`` (max pooling,
                      despite the parameter name)
          - lin_size: number of flattened features fed to each head
        """
        super(SmallConv, self).__init__()
        # Shared trunk; only the second and third convolutions are followed
        # by batch-norm, and the first one has no bias term.
        self.conv1 = nn.Conv2d(channels, 80, kernel_size=3, bias=False)
        self.conv2 = nn.Conv2d(80, 80, kernel_size=3)
        self.bn2 = nn.BatchNorm2d(80)
        self.conv3 = nn.Conv2d(80, 80, kernel_size=3)
        self.bn3 = nn.BatchNorm2d(80)

        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(avg_pool)

        self.linsize = lin_size

        # One fully connected head per task.
        self.fc = nn.ModuleList(
            [nn.Linear(self.linsize, num_cls) for _ in range(num_task)])

        # The heads are the only linear layers of the model; start them
        # with zero bias.
        for head in self.fc:
            head.bias.data.zero_()

    def forward(self, x, tasks):
        """
        Params:
          - x:     batch of input images
          - tasks: per-row task index selecting the head to apply

        Returns the logits of the selected head for every row. Rows whose
        task index matches no head keep all-zero logits.
        """
        feats = self.maxpool(self.relu(self.conv1(x)))
        feats = self.maxpool(self.relu(self.bn2(self.conv2(feats))))
        feats = self.maxpool(self.relu(self.bn3(self.conv3(feats))))
        feats = feats.view(-1, self.linsize)

        # Zero tensor with the right shape/dtype/device, obtained by scaling
        # the output of the first head by zero.
        logits = self.fc[0](feats) * 0

        for head_id, head in enumerate(self.fc):
            # Rows of the batch that belong to this head's task.
            rows = torch.nonzero(tasks == head_id, as_tuple=False).view(-1)
            if rows.numel() == 0:
                continue

            # Scatter the head's logits back into their original rows.
            logits.index_add_(0, rows, head(feats.index_select(0, rows)))

        return logits


### Split MNIST Data-handling

In [None]:
class ModMNIST(torchvision.datasets.MNIST):
    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Fetch a single sample.

        Args:
            index (int): Position of the sample in ``self.data``.

        Returns:
            tuple: (image, target). ``target`` is the entry stored at
            ``self.targets[index]``, passed through ``target_transform``
            when one is set; it is not converted to an int here.
        """
        target = self.targets[index]

        # Hand a PIL Image to the transform, consistent with the other
        # datasets.
        img = Image.fromarray(self.data[index].numpy(), mode='L')

        transform = self.transform
        target_transform = self.target_transform

        img = img if transform is None else transform(img)
        target = target if target_transform is None else target_transform(target)

        return img, target

In [None]:
class SplitMNISTHandler:
    """
    Object for the SplitMNIST dataset

    Loads the MNIST train and test sets, keeps only the digits named in
    ``tasks`` and relabels every sample as ``[task_id, label_id]`` where
    ``task_id`` is the position of the task in ``tasks`` and ``label_id`` is
    the position of the digit inside that task.
    """
    # Default dataset location, kept so that existing callers that pass only
    # ``tasks`` behave exactly as before.
    DEFAULT_ROOT = '/Users/ashwindesilva/research/modelzoo_continual/notebooks'

    def __init__(self, tasks, root=DEFAULT_ROOT):
        """
        Params:
          - tasks: list of tasks, each a list of digit labels
                   (labels are taken modulo 10)
          - root:  directory in which MNIST is stored / downloaded

        Raises:
          - ValueError: if ``tasks`` contains no label at all
        """
        mean_norm = [0.50]
        std_norm = [0.25]
        vanilla_transform = transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize(mean=mean_norm, std=std_norm)])

        trainset = ModMNIST(root, download=True, train=True, transform=vanilla_transform)
        testset = ModMNIST(root, download=True, train=False, transform=vanilla_transform)

        trainset.data, trainset.targets = self._split_by_task(trainset, tasks)
        testset.data, testset.targets = self._split_by_task(testset, tasks)

        self.trainset = trainset
        self.testset = testset

    @staticmethod
    def _split_by_task(dataset, tasks):
        """
        Select the samples of ``dataset`` used by ``tasks`` and build their
        ``[task_id, label_id]`` targets. Returns ``(data, targets)``.
        """
        indices, labels = [], []
        for task_id, tsk in enumerate(tasks):
            for lab_id, lab in enumerate(tsk):
                lab_ind = np.where(np.isin(dataset.targets, [lab % 10]))[0]
                indices.append(lab_ind)
                # One (task_id, lab_id) row per selected sample.
                labels.append(np.full((len(lab_ind), 2), (task_id, lab_id)))

        if not indices:
            raise ValueError("tasks must contain at least one label")

        indices = np.concatenate(indices)
        labels = np.concatenate(labels)
        return dataset.data[indices], [list(it) for it in labels]

    @staticmethod
    def _seed_worker(worker_id):
        """
        Used to fix randomization bug for pytorch dataloader + numpy
        Code from https://github.com/pytorch/pytorch/issues/5059

        Defined on the class (not as a closure inside a method) so that it
        can be pickled by reference when DataLoader workers are started
        with the ``spawn`` start method.
        """
        process_seed = torch.initial_seed()
        # Back out the base_seed so we can use all the bits.
        base_seed = process_seed - worker_id
        ss = np.random.SeedSequence([worker_id, base_seed])
        # More than 128 bits (4 32-bit words) would be overkill.
        np.random.seed(ss.generate_state(4))

    def get_data_loader(self, batch_size, train=True):
        """
        Get a Dataloader over all tasks.

        The train set is shuffled, the test set is iterated in order.
        """
        return DataLoader(
            self.trainset if train else self.testset,
            batch_size=batch_size, shuffle=bool(train),
            worker_init_fn=self._seed_worker, pin_memory=True, num_workers=4)

    def get_task_data_loader(self, task, batch_size, train=False):
        """
        Get Dataloader for a specific task

        The loader is not shuffled, so samples come out in the order in
        which they are stored for that task.
        """
        # Work on a copy so the handler's own datasets stay untouched.
        task_set = deepcopy(self.trainset if train else self.testset)

        targets = np.array(task_set.targets)
        task_mask = targets[:, 0] == task

        task_set.data = task_set.data[torch.from_numpy(task_mask)]
        task_set.targets = [(lab[0], lab[1]) for lab in targets[task_mask]]

        loader = DataLoader(
            task_set, batch_size=batch_size,
            shuffle=False, num_workers=6, pin_memory=True,
            worker_init_fn=self._seed_worker)

        return loader

### Multi-Head Learner

In [None]:
class MultiHead():
    """
    Object for initializing and training a multihead learner
    """
    def __init__(self, args, hp, data_conf):
        """
        Initialize multihead learner

        Params:
          - args:      dict of run settings; 'seed' and 'epochs' are read
          - hp:        dict of hyper-parameters config; 'batch_size', 'lr'
                       and 'l2_reg' are read
          - data_conf: dict of dataset config; 'tasks' is read
        """
        self.args = args
        self.hp = hp

        num_tasks = len(data_conf['tasks'])
        num_classes = len(data_conf['tasks'][0])

        # Random seed
        torch.manual_seed(abs(args['seed']))
        np.random.seed(abs(args['seed']))

        # Initialize Network. code assumes all tasks have same no. of classes
        self.net = SmallConv(
                    num_task=num_tasks,
                    num_cls=num_classes,
                    channels=1,
                    avg_pool=2,
                    lin_size=80)

        # Get dataset
        dataset = SplitMNISTHandler(data_conf['tasks'])
        self.train_loader = dataset.get_data_loader(hp['batch_size'], train=True)

        # Loss and Optimizer
        self.optimizer = torch.optim.SGD(self.net.parameters(), lr=hp['lr'],
                                   momentum=0.9, nesterov=True,
                                   weight_decay=hp['l2_reg'])
        # The scheduler is stepped once per batch, so its horizon is the
        # total number of optimisation steps.
        self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, args['epochs'] * len(self.train_loader))

    def train(self):
        """
        Train the multi-task learner

        Runs ``args['epochs']`` passes over the training loader, prints the
        average loss and accuracy after every epoch and returns the trained
        network (left on the device that was used for training).
        """
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.net.to(device)

        # The criterion is stateless; build it once for all epochs.
        criterion = nn.CrossEntropyLoss()

        # Train multi-head model
        for epoch in range(self.args['epochs']):
            train_loss = 0.0
            train_acc = 0.0
            batches = 0.0

            self.net.train()

            for dat, target in self.train_loader:
                self.optimizer.zero_grad()

                tasks, labels = target
                labels = labels.long()
                tasks = tasks.long()
                batch_size = int(labels.size()[0])

                dat = dat.to(device)
                labels = labels.to(device)
                tasks = tasks.to(device)

                # Forward/Back-prop
                out = self.net(dat, tasks)
                loss = criterion(out, labels)
                loss.backward()

                self.optimizer.step()

                self.lr_scheduler.step()

                # Compute Train metrics (sample-weighted sums)
                batches += batch_size
                train_loss += loss.item() * batch_size
                train_acc += (out.detach().argmax(dim=1) == labels).sum().item()

            # Guard against an empty loader so the summary never divides
            # by zero.
            seen = max(batches, 1.0)
            print("Epoch = {}".format(epoch))
            print("Train loss = {:.3f}".format(train_loss / seen))
            print("Train acc = {:.3f}".format(train_acc / seen))
            print("\n")

        return self.net

### Model Zoo

In [153]:
class ModelZoo():
    """
    Object for the modelzoo

    One multi-head learner is trained per round on a sampled subset of
    tasks. The softmax outputs of every learner are stored per task and the
    zoo's prediction for a task is the (weighted) average of those outputs.
    """
    def __init__(self, args, data_conf, hp_conf):
        """
        params:
          - args:      dict of run settings; 'replay_frac' is read here, the
                       dict is also handed to every MultiHead learner
          - data_conf: dict of dataset config; 'tasks' lists the labels of
                       each task
          - hp_conf:   dict of hyper-parameters handed to MultiHead
        """
        self.args = args
        self.tasks_info = data_conf['tasks']
        self.num_tasks = len(self.tasks_info)

        self.data_conf = data_conf
        self.hp_conf = hp_conf

        # Sampling weight of every task (uniform to start with)
        self.wts = np.ones(self.num_tasks)
        self.learner_task_idx = []

        # Random generator for sampling tasks in every boosting iteration
        self.rng = np.random.default_rng(seed=100)

        # Store train and test predictions of individual models
        self.tr_preds = {t_id: [] for t_id in range(self.num_tasks)}
        self.te_preds = {t_id: [] for t_id in range(self.num_tasks)}

        # Store the individual models at each round
        self.modelzoo = {}
        self.synergistic_tasks = {}

    def sample_tasks(self, rounds):
        """
        Choose the tasks of the learner trained in round ``rounds``.

        Previously seen tasks (indices below ``rounds``) are sampled
        according to ``self.wts``; task ``rounds`` itself is always added.
        Returns a copy of the data config restricted to the chosen tasks.
        """
        numsubtasks = min(2, rounds) # number of tasks sampled at each round
        if rounds != 0:
            pr = self.wts[:rounds] / np.sum(self.wts[:rounds])
            learner_task_idx = self.rng.choice(rounds,
                                               numsubtasks - 1,
                                               replace=False, p=pr) # sampling the tasks using a multinomial distribution
        else:
            learner_task_idx = np.array([])

        # Manually add the newly seen task (boosting should
        # automatically select this task due to the very large loss)
        learner_task_idx = np.append(learner_task_idx, int(rounds))
        learner_task_idx = np.array(learner_task_idx, dtype=np.int32)

        # store the info of the sampled tasks
        learner_task_info = [self.tasks_info[idx] for idx in learner_task_idx]

        self.learner_task_info = learner_task_info
        self.learner_task_idx = learner_task_idx

        learner_conf = deepcopy(self.data_conf)
        learner_conf['tasks'] = deepcopy(learner_task_info)
        self.synergistic_tasks[rounds] = learner_task_idx

        print("\n====== Round %d ======" % (rounds + 1))
        print("Sampled tasks: %s" % (str(learner_task_idx)))
        return learner_conf

    def add_learner(self, learner_conf, rounds):
        """
        Add a learner to the model-Zoo

        params:
          - learner_conf: dict describing Subset of tasks to train with
          - rounds:       key under which the trained net is stored
        """
        # Train a single "multi-head" learner and add it to the Model Zoo
        model = MultiHead(self.args, self.hp_conf, learner_conf)
        net = model.train()
        self.modelzoo[rounds] = net

        # Store all predictions of learner on train/test dataset
        # This allows us to discard the learners weights
        tr_ret = self.fetch_outputs(net, learner_conf['tasks'], True)
        te_ret = self.fetch_outputs(net, learner_conf['tasks'], False)
        for idx, t_id in enumerate(self.learner_task_idx):
            self.tr_preds[t_id].append(tr_ret[idx])
            self.te_preds[t_id].append(te_ret[idx])

    def fetch_outputs(self, net, tasks, tr_flag=False):
        """
        Compute the outputs of newly trained learner on the tasks it
        was trained on. The predictions of different learners are not
        combined so that they can be used to compute the error of the
        Model Zoo at any stage. This allows us to discard the weights
        of the individual learner.

        params:
          - net:     Neural net of newest learner
          - tasks:   Description of subset of tasks that neural net was
                     trained on
          - tr_flag: Determines whether to use train/test set

        Returns a list with one array of softmax outputs per task, in the
        order of ``tasks``.
        """
        dataset = SplitMNISTHandler(tasks)
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        # Make sure the net lives on the device the batches are sent to.
        net.to(device)

        test_loaders = [
            dataset.get_task_data_loader(t_id, 100, train=tr_flag)
            for t_id in range(len(tasks))]

        task_outputs = []
        net.eval()

        with torch.inference_mode():
            for dataloader in test_loaders:
                outputs = []
                for dat, target in dataloader:
                    # Only the task ids are needed to pick the head.
                    batch_tasks, _ = target
                    batch_tasks = batch_tasks.long()

                    dat = dat.to(device)
                    batch_tasks = batch_tasks.to(device)

                    out = net(dat, batch_tasks)
                    out = nn.functional.softmax(out, dim=1)
                    out = out.cpu().detach().numpy()
                    outputs.append(out)
                outputs = np.concatenate(outputs)
                task_outputs.append(outputs)
        return task_outputs

    def evaluate(self, rounds: int):
        """
        Evaluate the entire Model Zoo (combination of all learners)
        on the train and test sets and log the results

        params:
          - rounds: the printed summary covers the first ``rounds`` tasks;
                    the value is also stored under 'round' in the info dict

        Returns ``(train_losses, info)`` where ``train_losses`` is the list
        of per-task training losses and ``info`` is a dict of rounded
        metrics.
        """
        tr_ret = self.evaluate_preds(self.tr_preds, True)
        te_ret = self.evaluate_preds(self.te_preds, False)

        def rnd(x):
            return list(np.round(x, 3))

        info = {
            'round': rounds,
            'TrainLoss': rnd(tr_ret['Loss']),
            'TrainAcc': rnd(tr_ret['Accuracy']),
            'TestLoss': rnd(te_ret['Loss']),
            'TestAcc': rnd(te_ret['Accuracy']),
            'last_learner_tasks': list(self.learner_task_idx),
            'last_learner_weights': rnd(self.wts)
        }

        avg_acc = np.mean(info['TrainAcc'][:rounds]) if rounds > 0 else 0.0
        allacc = str(list(np.round(info['TrainAcc'][:rounds], 2)))
        print("Average accuracy of all seen tasks: %.2f" % (avg_acc))
        print("Individual accuracies of all seen tasks:\n%s" % (allacc))
        return tr_ret['Loss'], info

    def evaluate_preds(self, preds, tr_flag):
        """
        Use the set of predictions from all learners to compute the error
        and the loss of the entire Model Zoo

        params:
          - preds:   dict mapping task id to the list of stored softmax
                     outputs of the learners trained on that task
          - tr_flag: Determines whether to use train/test set

        Returns a dict with per-task 'Loss' and 'Accuracy' lists.
        """
        dataset = SplitMNISTHandler(self.tasks_info)
        criterion = nn.NLLLoss()
        numcls = len(self.tasks_info[0])

        test_loaders = [
            dataset.get_task_data_loader(t_id, 100, train=tr_flag)
            for t_id in range(self.num_tasks)]

        all_loss = []
        all_acc = []

        # Iterate over tasks and compute error/loss of Model Zoo on each task
        for task_id, dataloader in enumerate(test_loaders):
            count = 0
            acc = 0
            loss = 0

            # Compute the outputs of the entire Model Zoo by ensemble
            # averaging of the predictions of all learners
            if len(preds[task_id]) == 0:
                # If model has no prediction, output uniform probabilities
                numpts = len(dataloader.dataset)
                curpred = np.ones((numpts, numcls)) / numcls
            else:
                # If limited replay was used, apply a weighted ensemble. The
                # rationale is that we increase the weight of a learner if it
                # trained on more samples. This is true for the first learner
                # trained on a task (wts[0] hence has a higher weight)
                wts = np.ones(len(preds[task_id]))
                wts[0] = 1 / self.args['replay_frac']
                curpred = np.average(preds[task_id], axis=0, weights=wts)

            # Compute error/loss using outputs of Model Zoo (curpred).
            # Only the labels are needed here: the stored predictions are
            # in loader order, so no tensor has to be moved to a device.
            for _, target in dataloader:
                _, labels = target
                labels = labels.long()
                batch_size = int(labels.size()[0])

                out = curpred[count:count + batch_size]
                out = torch.log(torch.Tensor(out))

                loss += (criterion(out, labels).item()) * batch_size

                labels = labels.cpu().numpy()
                out = out.cpu().detach().numpy()
                acc += np.sum(labels == (np.argmax(out, axis=1)))
                count += batch_size

            all_loss.append(loss / count)
            all_acc.append(acc / count)

        info = {'Loss': all_loss,
                'Accuracy': all_acc,
                'train': tr_flag}
        return info

    def update_task_wts(self, losses):
        """
        Update the sampling weights based on the losses. self.wts should
        ideally be based on the transfer exponent $\\rho$. We however, use
        the (normalized) training loss like in boosting

        params:
          - losses: List of training losses on various tasks

        Returns the new weights (also stored in ``self.wts``).
        """
        losses = np.asarray(losses, dtype=float)
        mean_loss = np.mean(losses)

        if not np.isfinite(mean_loss) or mean_loss == 0:
            # The normalisation below would yield NaN weights, which cannot
            # be used as sampling probabilities; fall back to uniform ones.
            wts = np.ones_like(losses)
        else:
            wts = np.exp((losses - mean_loss) / mean_loss)
            wts = np.clip(wts, 0.0001, 1000)

        self.wts = wts
        return wts

    def train(self):
        """
        Train the Model Zoo
        """
        self.evaluate(0)
        for rounds in range(self.num_tasks):
            learner_conf = self.sample_tasks(rounds)
            self.add_learner(learner_conf, rounds)
            # evaluate() returns (train losses, info); only the losses
            # drive the re-weighting.
            # NOTE(review): the summary printed by evaluate(rounds) covers
            # tasks [0, rounds) and so excludes the task just trained --
            # confirm whether rounds + 1 was intended.
            losses, _ = self.evaluate(rounds)
            self.update_task_wts(losses)

    # def predict(self, x, task_ids):
    #     """
    #     Predict using Model Zoo
    #     """
    #     for task in task_ids:
    #       for round in self.synergistic_tasks.keys():
    #         if task is in self.synergistic_tasks[round]:
    #           net = self.modelzoo[round]
              

### Train Model Zoo

In [None]:
import argparse

# Run-level settings handed to the Model Zoo and to every learner.
args = dict(seed=100, tasks_per_round=5, epochs=1, replay_frac=1.0)

# Dataset description: one list of digit labels per task.
data_conf = dict(
    data='mnist',
    tasks=[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]],
)

# Optimiser hyper-parameters of each learner.
hp_conf = dict(batch_size=16, lr=0.01, l2_reg=1e-5)

zoo = ModelZoo(args, data_conf, hp_conf)

In [None]:
# Step through the rounds by hand (one new task per round) so that the
# intermediate results can be inspected.
for rounds in range(zoo.num_tasks):
    # Pick the tasks of this round's learner and show the resulting config.
    learner_conf = zoo.sample_tasks(rounds)
    print(learner_conf)

    # Train the learner and store its predictions.
    zoo.add_learner(learner_conf, rounds)

    # Evaluate the ensemble; the training losses drive the re-weighting.
    losses, info = zoo.evaluate(rounds)
    print(info)
    # Number of test predictions stored so far for task 0.
    print(len(zoo.te_preds[0]))

    zoo.update_task_wts(losses)

### Inference using Model Zoo

# Digit pairs that make up the five tasks, and one id per task.
tasks = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
task_ids = list(range(len(tasks)))
dataset = SplitMNISTHandler(tasks)

In [None]:
# One test-set loader per task, in the order of task_ids.
test_loaders = [
    dataset.get_task_data_loader(t_id, 100, train=False)
    for t_id in task_ids
]

In [None]:
def get_network_output(net, dataloader):
    """
    Run ``net`` over every batch of ``dataloader`` and collect the softmax
    outputs.

    params:
      - net:        network called as ``net(data, tasks)``; it is switched
                    to eval mode and moved to the GPU when one is available
      - dataloader: iterable of ``(data, (tasks, labels))`` batches

    Returns a single numpy array with the per-class probabilities of all
    batches, concatenated in iteration order.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    net.eval()
    net.to(device)

    batch_probs = []
    with torch.inference_mode():
        for dat, (tasks, labels) in dataloader:
            tasks = tasks.long().to(device)
            # The labels are converted and moved as well, although the
            # forward pass below does not use them.
            labels = labels.long().to(device)

            logits = net(dat.to(device), tasks)
            probs = nn.functional.softmax(logits, dim=1)
            batch_probs.append(probs.cpu().detach().numpy())

    return np.concatenate(batch_probs)

In [None]:
def _with_head_index(loader, head_idx):
    """
    Yield the batches of ``loader`` with every task id replaced by
    ``head_idx``.

    The loaders built from the full task list carry global task ids, while
    a learner's heads are indexed by the position of the task inside the
    subset that learner was trained on.
    """
    for dat, (batch_tasks, labels) in loader:
        yield dat, (torch.full_like(batch_tasks, head_idx), labels)


# Softmax outputs of every learner that was trained on a task, grouped by
# task id.
preds = {t_id: [] for t_id in range(len(tasks))}
for t_id in task_ids:
    print("Task ID : {}".format(t_id))
    dataloader = test_loaders[t_id]
    # ``round_id`` rather than ``round`` so the builtin is not shadowed.
    for round_id, round_tasks in zoo.synergistic_tasks.items():
        if t_id not in round_tasks:
            continue
        print(round_id)
        net = zoo.modelzoo[round_id]
        # Position of this task among the tasks of that round's learner,
        # i.e. the index of the head that was trained for it.
        head_idx = list(round_tasks).index(t_id)
        outputs = get_network_output(
            net, _with_head_index(dataloader, head_idx))
        preds[t_id].append(outputs)