# Imports

In [None]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms
import torchvision.models as models

from tqdm import tqdm
from typing import List

In [None]:
# Root directory holding all datasets. Defaults to the original "/datasets"
# layout; set the DATASETS_ROOT environment variable to run elsewhere.
datasets_root = os.environ.get("DATASETS_ROOT", "/datasets")

mnist_path = os.path.join(datasets_root, "mnist")
cifar_path = os.path.join(datasets_root, "cifar")
flower_path = os.path.join(datasets_root, "flowers102")

In [None]:
# Run on the default CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")

# Datasets

## CIFAR-100

In [None]:
def get_cifar100(batch_size = 64):
  """Build class-incremental CIFAR-100 loaders: 5 tasks of 20 consecutive classes.

  Task ``i`` (0-based) contains the classes ``[20*i, 20*(i+1))``. Labels are
  passed through unchanged, so callers with a 20-way head must remap them.

  Args:
    batch_size: Mini-batch size for every train and test loader.

  Returns:
    ``(train_task_loaders, test_task_loaders)``: two lists of 5 DataLoaders,
    one per task, in task order. Train loaders shuffle; test loaders do not.
  """
  # Per-channel mean/std shared by the train and test pipelines.
  normalize = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))

  train_transform = transforms.Compose([
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      normalize])

  test_transform = transforms.Compose([
      transforms.ToTensor(),
      normalize])

  # Load the CIFAR-100 dataset
  train_dataset = torchvision.datasets.CIFAR100(cifar_path, train=True, download=True, transform=train_transform)
  test_dataset = torchvision.datasets.CIFAR100(cifar_path, train=False, download=True, transform=test_transform)

  num_classes_per_task = 20
  num_tasks = 5

  # Read labels from ``targets`` instead of ``dataset[j][1]``: indexing the
  # dataset decodes and transforms the image, and doing it inside the task
  # loop meant 5 full passes over all 60k images just to find the labels.
  train_targets = train_dataset.targets
  test_targets = test_dataset.targets

  train_task_loaders = []
  test_task_loaders = []

  for i in range(num_tasks):
    # A range gives O(1) membership tests for integer labels.
    classes = range(i * num_classes_per_task, (i + 1) * num_classes_per_task)
    train_indices = [j for j, label in enumerate(train_targets) if label in classes]
    test_indices = [j for j, label in enumerate(test_targets) if label in classes]

    train_task_dataset = torch.utils.data.Subset(train_dataset, train_indices)
    test_task_dataset = torch.utils.data.Subset(test_dataset, test_indices)
    train_loader = DataLoader(train_task_dataset, batch_size = batch_size, shuffle = True, num_workers = 2)
    test_loader = DataLoader(test_task_dataset, batch_size = batch_size, shuffle = False)

    train_task_loaders.append(train_loader)
    test_task_loaders.append(test_loader)

  return train_task_loaders, test_task_loaders

## Permuted MNIST

In [None]:
class PermuteMNISTTask:
    """Transform that shuffles the pixels of 32x32 image tensors with a fixed permutation.

    Each leading slice (e.g. each channel) is flattened to 1024 pixels,
    reordered by ``permutation`` and reshaped back to 32x32, so every slice
    receives the same pixel shuffle.
    """

    def __init__(self, permutation):
        # Index tensor over the 32*32 flattened pixel positions.
        self.permutation = permutation

    def __call__(self, x):
        side = 32
        pixels = x.view(-1, side * side)
        return pixels[:, self.permutation].view(-1, side, side)

def get_permuted_mnist(batch_size = 64, num_tasks = 5, seed = None):
  """Build Permuted-MNIST loaders: one task per random pixel permutation.

  Each digit is expanded to 3 identical grey channels, resized to 32x32,
  converted to a tensor and then has its 32*32 pixels reordered by a
  task-specific permutation (the same one for all channels). Labels are
  passed through unchanged.

  Args:
    batch_size: Mini-batch size for every train and test loader.
    num_tasks: Number of tasks, i.e. distinct pixel permutations (default 5).
    seed: Optional int. When given, permutations are drawn from a dedicated
      ``torch.Generator`` seeded with it, so they are reproducible without
      touching the global RNG. When ``None`` the global torch RNG is used.

  Returns:
    ``(train_loaders, test_loaders)``: lists of ``num_tasks`` DataLoaders in
    task order. Train loaders shuffle; test loaders do not.
  """
  num_pixels = 32 * 32

  generator = None
  if seed is not None:
    generator = torch.Generator().manual_seed(seed)

  # One random permutation of the pixel positions per task.
  permutations = [torch.randperm(num_pixels, generator = generator) for _ in range(num_tasks)]

  train_loaders = []
  test_loaders = []

  for permutation in permutations:
    task_transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize(32),
        transforms.ToTensor(),
        PermuteMNISTTask(permutation)
    ])

    train_dataset = torchvision.datasets.MNIST(mnist_path, train=True, download=True, transform = task_transform)
    test_dataset = torchvision.datasets.MNIST(mnist_path, train=False, download=True, transform = task_transform)

    train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True, num_workers = 2)
    test_loader = DataLoader(test_dataset, batch_size = batch_size, shuffle = False)

    train_loaders.append(train_loader)
    test_loaders.append(test_loader)

  return train_loaders, test_loaders

## Flowers-102 

In [None]:
def get_flowers102(batch_size = 16):
  """Build class-incremental Flowers-102 loaders: 6 tasks of 17 consecutive classes.

  Task ``i`` (0-based) contains the classes ``[17*i, 17*(i+1))``, so the six
  tasks span class ids 0-101. Labels are passed through unchanged.

  Args:
    batch_size: Mini-batch size for every train and test loader.

  Returns:
    ``(train_task_loaders, test_task_loaders)``: two lists of 6 DataLoaders,
    one per task, in task order. Train loaders shuffle; test loaders do not.
  """
  # NOTE(review): these are the same mean/std values used for CIFAR-100 in
  # this notebook, not Flowers-102 statistics — confirm this is intended.
  normalize = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))

  train_transform = transforms.Compose([
      transforms.Resize([256, 256]),
      transforms.CenterCrop(224),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      normalize])

  test_transform = transforms.Compose([
      transforms.Resize([256, 256]),
      transforms.CenterCrop(224),
      transforms.ToTensor(),
      normalize])

  train_dataset = torchvision.datasets.Flowers102(flower_path, split="train", download=True, transform=train_transform)
  test_dataset = torchvision.datasets.Flowers102(flower_path, split="test", download=True, transform=test_transform)

  num_classes_per_task = 17
  num_tasks = 6

  def labels_of(dataset):
    # Read the labels once, without loading images. torchvision's Flowers102
    # keeps them in the private ``_labels`` list; if that attribute is ever
    # missing, fall back to a single indexed pass (which decodes every image)
    # instead of the previous one-pass-per-task.
    labels = getattr(dataset, "_labels", None)
    if labels is None:
      labels = [dataset[j][1] for j in range(len(dataset))]
    return labels

  train_targets = labels_of(train_dataset)
  test_targets = labels_of(test_dataset)

  train_task_loaders = []
  test_task_loaders = []

  # Create tasks
  for i in range(num_tasks):
    # A range gives O(1) membership tests for integer labels.
    classes = range(i * num_classes_per_task, (i + 1) * num_classes_per_task)
    train_indices = [j for j, label in enumerate(train_targets) if label in classes]
    test_indices = [j for j, label in enumerate(test_targets) if label in classes]

    train_task_dataset = torch.utils.data.Subset(train_dataset, train_indices)
    test_task_dataset = torch.utils.data.Subset(test_dataset, test_indices)
    train_loader = DataLoader(train_task_dataset, batch_size = batch_size, shuffle = True, num_workers = 2)
    test_loader = DataLoader(test_task_dataset, batch_size = batch_size, shuffle = False)

    train_task_loaders.append(train_loader)
    test_task_loaders.append(test_loader)

  return train_task_loaders, test_task_loaders

## Different Datasets (MNIST, CIFAR-100, Flowers-102)

In [None]:
def get_diff_datasets(batch_size = 32):
  """ Create an experiment with MNIST, CIFAR100 and Flowers102 as different tasks  """
  # Every image is resized to 256x256, center-cropped to 224x224, converted to
  # a tensor and normalised with one shared mean/std; MNIST is first expanded
  # to 3 grey channels. Returns (trainloaders, testloaders), each a list of
  # three loaders ordered MNIST, CIFAR-100, Flowers-102.

  def build_pipeline(augment):
    # Shared resize/crop/normalise chain; the train variant adds a random flip.
    steps = [transforms.Resize([256, 256]), transforms.CenterCrop(224)]
    if augment:
      steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)))
    return transforms.Compose(steps)

  train_transform = build_pipeline(augment=True)
  test_transform = build_pipeline(augment=False)

  # (train, test) dataset pairs, constructed in task order.
  splits = [
      (torchvision.datasets.MNIST(mnist_path, train=True, download=True,
                                  transform = transforms.Compose([transforms.Grayscale(num_output_channels=3), train_transform])),
       torchvision.datasets.MNIST(mnist_path, train=False, download=True,
                                  transform = transforms.Compose([transforms.Grayscale(num_output_channels=3), test_transform]))),
      (torchvision.datasets.CIFAR100(cifar_path, train=True, download=True, transform=train_transform),
       torchvision.datasets.CIFAR100(cifar_path, train=False, download=True, transform=test_transform)),
      (torchvision.datasets.Flowers102(flower_path, split="train", download=True, transform=train_transform),
       torchvision.datasets.Flowers102(flower_path, split="test", download=True, transform=test_transform)),
  ]

  trainloaders = [DataLoader(train_set, batch_size = batch_size, shuffle = True, num_workers = 2)
                  for train_set, _ in splits]
  testloaders = [DataLoader(test_set, batch_size = batch_size, shuffle = False)
                 for _, test_set in splits]

  return trainloaders, testloaders

# Baselines

## Finetune

In [None]:
class Finetune:
  """Fine-tuning baseline: one ResNet-34 trained on the tasks one after another.

  All tasks share a single ``num_outputs``-way classifier; dataset labels are
  folded into it with ``labels % num_outputs``. Nothing protects earlier
  tasks from being overwritten while later tasks are trained.
  """

  def __init__(self, num_outputs, epochs, lr):
    """
    Args:
      num_outputs: Size of the shared classification head.
      epochs: Number of training epochs per task.
      lr: Adam learning rate.
    """
    self.num_outputs = num_outputs

    # Randomly initialised ResNet-34 (no pretrained weights) whose final
    # layer is replaced to emit ``num_outputs`` logits.
    self.model = models.resnet34(weights = None)
    self.model.fc = nn.Linear(self.model.fc.in_features, self.num_outputs)

    self.epochs = epochs
    self.lr = lr
    self.lossFunc = nn.CrossEntropyLoss()
    # Number of tasks trained so far, i.e. the 1-based id of the latest task.
    self.task_idx = 0

  def train(self, trainloader, testloader):
    """Train on the next task for ``self.epochs`` epochs, then evaluate it.

    A new Adam optimizer is created for each task, so optimizer state is not
    carried over between tasks.

    Returns:
      The accuracy on ``testloader`` as returned by ``eval``.
    """
    self.task_idx += 1
    self.model.train()
    self.model.to(device)
    optimizer = torch.optim.Adam(self.model.parameters(), lr = self.lr)

    for epoch in range(self.epochs):
      for inputs, labels in tqdm(trainloader, desc = "Training task {}, Epoch {}/{}".format(self.task_idx, epoch + 1, self.epochs)):

        inputs = inputs.to(device)
        labels = labels.to(device)
        # Fold dataset labels into the shared head.
        labels = labels % self.num_outputs

        optimizer.zero_grad()

        outputs = self.model(inputs)
        loss = self.lossFunc(outputs, labels)

        loss.backward()
        optimizer.step()

    return self.eval(self.task_idx, testloader)

  def eval(self, task_idx, testloader):
    """Compute and print top-1 accuracy of the current model on ``testloader``.

    Args:
      task_idx: Task number; only used in the progress bar and printed message.
      testloader: Iterable of ``(inputs, labels)`` batches.

    Returns:
      Accuracy as a float in [0, 1], or ``None`` if ``testloader`` yields no samples.
    """
    eval_total = 0
    eval_correct = 0
    self.model.eval()
    self.model.to(device)
    # Inference only: no_grad stops autograd from recording each forward pass
    # and saving activations for a backward pass that never happens.
    with torch.no_grad():
      for inputs, labels in tqdm(testloader, desc = "Evaluating Task {}".format(task_idx)):

        inputs = inputs.to(device)
        labels = labels.to(device)
        labels = labels % self.num_outputs

        outputs = self.model(inputs)

        _, predictions = torch.max(outputs, dim = 1)
        eval_total += labels.shape[0]
        eval_correct += int((predictions == labels).sum())

    # Guard against an empty test set instead of dividing by zero.
    if eval_total == 0:
      print("Evaluation skipped on Task {} : empty test set".format(task_idx))
      return None

    accuracy = eval_correct / eval_total
    print("Evaluation Accucary on Task {} : {}".format(task_idx, accuracy))
    return accuracy

## Individual Networks

In [None]:
class IndivindualNets:
  """Baseline that trains an independent ResNet-34 for every task.

  Model ``k`` (1-based) is trained only on task ``k`` and evaluated only on
  that task, so tasks never share weights.
  """

  def __init__(self, num_models, num_outputs, epochs):
    """
    Args:
      num_models: Number of tasks, i.e. independent networks to create.
      num_outputs: Head size. Either a single int shared by all models, or a
        list/tuple with one entry per model.
      epochs: Number of training epochs per task.

    Raises:
      ValueError: if a per-model ``num_outputs`` sequence has fewer than
        ``num_models`` entries.
    """
    self.num_models = num_models

    # Accept any list or tuple of per-model sizes (typing.List only matched lists).
    if isinstance(num_outputs, (list, tuple)):
      self.num_outputs = list(num_outputs)
      if len(self.num_outputs) < self.num_models:
        raise ValueError("num_outputs has {} entries but num_models is {}".format(len(self.num_outputs), self.num_models))
    else:
      self.num_outputs = [num_outputs for _ in range(self.num_models)]

    self.models = [models.resnet34(weights = None) for _ in range (self.num_models)]

    #Replace the classifier for each model to match the number of classes of each task
    for index, model in enumerate(self.models):
      model.fc = nn.Linear(model.fc.in_features, self.num_outputs[index])

    self.epochs = epochs
    self.lossFunc = nn.CrossEntropyLoss()
    # Number of tasks trained so far, i.e. the 1-based id of the latest task.
    self.task_idx = 0

  def train(self, trainloader, testloader, lr):
    """Train the next task's own model with Adam at ``lr``, then evaluate it.

    Returns:
      The accuracy on ``testloader`` as returned by ``eval``.

    Raises:
      IndexError: if every one of the ``num_models`` models was already trained.
    """
    if self.task_idx >= self.num_models:
      raise IndexError("All {} models have already been trained".format(self.num_models))

    self.task_idx += 1
    model = self.models[self.task_idx - 1]
    model.train()
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr = lr)

    for epoch in range(self.epochs):
      for inputs, labels in tqdm(trainloader, desc = "Training task {}, Epoch {}/{}".format(self.task_idx, epoch + 1, self.epochs)):

        inputs = inputs.to(device)
        labels = labels.to(device)
        # Fold dataset labels into this task's head.
        labels = labels % self.num_outputs[self.task_idx - 1]

        optimizer.zero_grad()
        outputs = model(inputs)

        loss = self.lossFunc(outputs, labels)
        loss.backward()
        optimizer.step()

    return self.eval(self.task_idx, testloader)

  def eval(self, task_idx, testloader):
    """Compute and print top-1 accuracy of task ``task_idx``'s model on ``testloader``.

    Args:
      task_idx: 1-based task number selecting which model to evaluate.
      testloader: Iterable of ``(inputs, labels)`` batches.

    Returns:
      Accuracy as a float in [0, 1], or ``None`` if ``testloader`` yields no samples.

    Raises:
      IndexError: if ``task_idx`` is outside ``1..num_models`` (previously
        ``task_idx=0`` silently selected the last model).
    """
    if not 1 <= task_idx <= self.num_models:
      raise IndexError("task_idx must be in 1..{}, got {}".format(self.num_models, task_idx))

    model = self.models[task_idx - 1]
    model.eval()
    model.to(device)

    eval_total = 0
    eval_correct = 0
    # Inference only: no_grad stops autograd from recording each forward pass
    # and saving activations for a backward pass that never happens.
    with torch.no_grad():
      for inputs, labels in tqdm(testloader, desc = "Evaluating Task {}".format(task_idx)):
        inputs = inputs.to(device)
        labels = labels.to(device)
        labels = labels % self.num_outputs[task_idx - 1]

        outputs = model(inputs)

        _, predictions = torch.max(outputs, dim = 1)
        eval_total += labels.shape[0]
        eval_correct += int((predictions == labels).sum())

    # Guard against an empty test set instead of dividing by zero.
    if eval_total == 0:
      print("Evaluation skipped on Task {} : empty test set".format(task_idx))
      return None

    accuracy = eval_correct / eval_total
    print("Evaluation Accucary on Task {} : {}".format(task_idx, accuracy))
    return accuracy

# CIFAR-100 Example

In [None]:
# Split CIFAR-100: 5 tasks of 20 classes, hence a 20-way head.
num_outputs = 20
lr = 1e-3
epochs = 10
trainloaders, testloaders = get_cifar100()

## Finetune

In [None]:
finetune = Finetune(num_outputs, epochs, lr)
# Visit the tasks in order; each call trains on one task, then evaluates it.
for task_loaders in zip(trainloaders, testloaders):
  finetune.train(*task_loaders)

In [None]:
# Re-evaluate every task with the final model, after all tasks were trained.
for task_number, testloader in enumerate(testloaders, start = 1):
  finetune.eval(task_number, testloader)

## Individual Networks

In [None]:
indNets = IndivindualNets(5, num_outputs, epochs)
# Each task trains its own network at the shared learning rate.
for task_loaders in zip(trainloaders, testloaders):
  indNets.train(*task_loaders, lr)

In [None]:
# Evaluate each task's own network on that task's test set.
for task_number, testloader in enumerate(testloaders, start = 1):
  indNets.eval(task_number, testloader)

# Permuted MNIST Example

In [None]:
# Permuted MNIST: every task keeps the 10 digit labels, hence a 10-way head.
num_outputs = 10
lr = 1e-3
epochs = 5
trainloaders, testloaders = get_permuted_mnist()

## Finetune

In [None]:
finetune = Finetune(num_outputs, epochs, lr)
# Visit the tasks in order; each call trains on one task, then evaluates it.
for task_loaders in zip(trainloaders, testloaders):
  finetune.train(*task_loaders)

In [None]:
# Re-evaluate every task with the final model, after all tasks were trained.
for task_number, testloader in enumerate(testloaders, start = 1):
  finetune.eval(task_number, testloader)

## Individual Networks

In [None]:
indNets = IndivindualNets(5, num_outputs, epochs)
# Each task trains its own network at the shared learning rate.
for task_loaders in zip(trainloaders, testloaders):
  indNets.train(*task_loaders, lr)

In [None]:
# Evaluate each task's own network on that task's test set.
for task_number, testloader in enumerate(testloaders, start = 1):
  indNets.eval(task_number, testloader)

# Flowers-102 Example

In [None]:
# Split Flowers-102: 6 tasks of 17 classes, hence a 17-way head.
num_outputs = 17
lr = 1e-3
epochs = 20
trainloaders, testloaders = get_flowers102()

## Finetune

In [None]:
finetune = Finetune(num_outputs, epochs, lr)
# Visit the tasks in order; each call trains on one task, then evaluates it.
for task_loaders in zip(trainloaders, testloaders):
  finetune.train(*task_loaders)

In [None]:
# Re-evaluate every task with the final model, after all tasks were trained.
for task_number, testloader in enumerate(testloaders, start = 1):
  finetune.eval(task_number, testloader)

Evaluating Task 1: 100%|██████████| 40/40 [00:06<00:00,  5.89it/s]


Evaluation Accucary on Task 1 : 0.06050955414012739


Evaluating Task 2: 100%|██████████| 39/39 [00:07<00:00,  5.19it/s]


Evaluation Accucary on Task 2 : 0.052202283849918436


Evaluating Task 3: 100%|██████████| 77/77 [00:14<00:00,  5.38it/s]


Evaluation Accucary on Task 3 : 0.021103896103896104


Evaluating Task 4: 100%|██████████| 57/57 [00:10<00:00,  5.27it/s]


Evaluation Accucary on Task 4 : 0.05077262693156733


Evaluating Task 5: 100%|██████████| 103/103 [00:18<00:00,  5.59it/s]


Evaluation Accucary on Task 5 : 0.09363525091799266


Evaluating Task 6: 100%|██████████| 71/71 [00:12<00:00,  5.46it/s]

Evaluation Accucary on Task 6 : 0.29313380281690143





## Individual Networks

In [None]:
indNets = IndivindualNets(6, num_outputs, epochs)
# The per-task networks here train at a fixed lr of 1e-4, not the `lr` of this section.
for task_loaders in zip(trainloaders, testloaders):
  indNets.train(*task_loaders, 1e-4)

Training task 1, Epoch 1/20: 100%|██████████| 11/11 [00:02<00:00,  5.27it/s]
Training task 1, Epoch 2/20: 100%|██████████| 11/11 [00:01<00:00,  6.02it/s]
Training task 1, Epoch 3/20: 100%|██████████| 11/11 [00:01<00:00,  6.32it/s]
Training task 1, Epoch 4/20: 100%|██████████| 11/11 [00:02<00:00,  4.73it/s]
Training task 1, Epoch 5/20: 100%|██████████| 11/11 [00:02<00:00,  4.70it/s]
Training task 1, Epoch 6/20: 100%|██████████| 11/11 [00:01<00:00,  6.13it/s]
Training task 1, Epoch 7/20: 100%|██████████| 11/11 [00:01<00:00,  6.19it/s]
Training task 1, Epoch 8/20: 100%|██████████| 11/11 [00:01<00:00,  6.33it/s]
Training task 1, Epoch 9/20: 100%|██████████| 11/11 [00:01<00:00,  6.27it/s]
Training task 1, Epoch 10/20: 100%|██████████| 11/11 [00:01<00:00,  6.44it/s]
Training task 1, Epoch 11/20: 100%|██████████| 11/11 [00:02<00:00,  3.91it/s]
Training task 1, Epoch 12/20: 100%|██████████| 11/11 [00:01<00:00,  5.99it/s]
Training task 1, Epoch 13/20: 100%|██████████| 11/11 [00:01<00:00,  6.39i

Evaluation Accucary on Task 1 : 0.4856687898089172


Training task 2, Epoch 1/20: 100%|██████████| 11/11 [00:02<00:00,  4.40it/s]
Training task 2, Epoch 2/20: 100%|██████████| 11/11 [00:01<00:00,  6.47it/s]
Training task 2, Epoch 3/20: 100%|██████████| 11/11 [00:01<00:00,  6.34it/s]
Training task 2, Epoch 4/20: 100%|██████████| 11/11 [00:01<00:00,  6.48it/s]
Training task 2, Epoch 5/20: 100%|██████████| 11/11 [00:01<00:00,  6.41it/s]
Training task 2, Epoch 6/20: 100%|██████████| 11/11 [00:01<00:00,  6.31it/s]
Training task 2, Epoch 7/20: 100%|██████████| 11/11 [00:02<00:00,  4.33it/s]
Training task 2, Epoch 8/20: 100%|██████████| 11/11 [00:02<00:00,  5.37it/s]
Training task 2, Epoch 9/20: 100%|██████████| 11/11 [00:01<00:00,  6.04it/s]
Training task 2, Epoch 10/20: 100%|██████████| 11/11 [00:01<00:00,  6.57it/s]
Training task 2, Epoch 11/20: 100%|██████████| 11/11 [00:01<00:00,  6.31it/s]
Training task 2, Epoch 12/20: 100%|██████████| 11/11 [00:01<00:00,  6.35it/s]
Training task 2, Epoch 13/20: 100%|██████████| 11/11 [00:01<00:00,  6.23i

Evaluation Accucary on Task 2 : 0.42251223491027734


Training task 3, Epoch 1/20: 100%|██████████| 11/11 [00:01<00:00,  6.27it/s]
Training task 3, Epoch 2/20: 100%|██████████| 11/11 [00:01<00:00,  6.37it/s]
Training task 3, Epoch 3/20: 100%|██████████| 11/11 [00:01<00:00,  5.66it/s]
Training task 3, Epoch 4/20: 100%|██████████| 11/11 [00:02<00:00,  4.16it/s]
Training task 3, Epoch 5/20: 100%|██████████| 11/11 [00:01<00:00,  6.39it/s]
Training task 3, Epoch 6/20: 100%|██████████| 11/11 [00:01<00:00,  6.38it/s]
Training task 3, Epoch 7/20: 100%|██████████| 11/11 [00:01<00:00,  6.33it/s]
Training task 3, Epoch 8/20: 100%|██████████| 11/11 [00:01<00:00,  6.18it/s]
Training task 3, Epoch 9/20: 100%|██████████| 11/11 [00:01<00:00,  6.28it/s]
Training task 3, Epoch 10/20: 100%|██████████| 11/11 [00:02<00:00,  4.41it/s]
Training task 3, Epoch 11/20: 100%|██████████| 11/11 [00:02<00:00,  5.09it/s]
Training task 3, Epoch 12/20: 100%|██████████| 11/11 [00:01<00:00,  6.38it/s]
Training task 3, Epoch 13/20: 100%|██████████| 11/11 [00:01<00:00,  6.28i

Evaluation Accucary on Task 3 : 0.4180194805194805


Training task 4, Epoch 1/20: 100%|██████████| 11/11 [00:01<00:00,  6.19it/s]
Training task 4, Epoch 2/20: 100%|██████████| 11/11 [00:01<00:00,  5.87it/s]
Training task 4, Epoch 3/20: 100%|██████████| 11/11 [00:02<00:00,  4.04it/s]
Training task 4, Epoch 4/20: 100%|██████████| 11/11 [00:01<00:00,  6.27it/s]
Training task 4, Epoch 5/20: 100%|██████████| 11/11 [00:01<00:00,  6.28it/s]
Training task 4, Epoch 6/20: 100%|██████████| 11/11 [00:01<00:00,  6.11it/s]
Training task 4, Epoch 7/20: 100%|██████████| 11/11 [00:01<00:00,  6.31it/s]
Training task 4, Epoch 8/20: 100%|██████████| 11/11 [00:01<00:00,  6.24it/s]
Training task 4, Epoch 9/20: 100%|██████████| 11/11 [00:02<00:00,  4.60it/s]
Training task 4, Epoch 10/20: 100%|██████████| 11/11 [00:02<00:00,  4.81it/s]
Training task 4, Epoch 11/20: 100%|██████████| 11/11 [00:01<00:00,  6.16it/s]
Training task 4, Epoch 12/20: 100%|██████████| 11/11 [00:01<00:00,  6.34it/s]
Training task 4, Epoch 13/20: 100%|██████████| 11/11 [00:01<00:00,  6.47i

Evaluation Accucary on Task 4 : 0.7240618101545254


Training task 5, Epoch 1/20: 100%|██████████| 11/11 [00:01<00:00,  6.37it/s]
Training task 5, Epoch 2/20: 100%|██████████| 11/11 [00:01<00:00,  6.29it/s]
Training task 5, Epoch 3/20: 100%|██████████| 11/11 [00:01<00:00,  6.17it/s]
Training task 5, Epoch 4/20: 100%|██████████| 11/11 [00:02<00:00,  3.93it/s]
Training task 5, Epoch 5/20: 100%|██████████| 11/11 [00:01<00:00,  6.48it/s]
Training task 5, Epoch 6/20: 100%|██████████| 11/11 [00:01<00:00,  6.34it/s]
Training task 5, Epoch 7/20: 100%|██████████| 11/11 [00:01<00:00,  6.24it/s]
Training task 5, Epoch 8/20: 100%|██████████| 11/11 [00:01<00:00,  6.33it/s]
Training task 5, Epoch 9/20: 100%|██████████| 11/11 [00:01<00:00,  6.31it/s]
Training task 5, Epoch 10/20: 100%|██████████| 11/11 [00:02<00:00,  5.18it/s]
Training task 5, Epoch 11/20: 100%|██████████| 11/11 [00:02<00:00,  4.71it/s]
Training task 5, Epoch 12/20: 100%|██████████| 11/11 [00:01<00:00,  6.33it/s]
Training task 5, Epoch 13/20: 100%|██████████| 11/11 [00:01<00:00,  6.59i

Evaluation Accucary on Task 5 : 0.3812729498164015


Training task 6, Epoch 1/20: 100%|██████████| 11/11 [00:02<00:00,  4.17it/s]
Training task 6, Epoch 2/20: 100%|██████████| 11/11 [00:01<00:00,  6.31it/s]
Training task 6, Epoch 3/20: 100%|██████████| 11/11 [00:01<00:00,  6.42it/s]
Training task 6, Epoch 4/20: 100%|██████████| 11/11 [00:01<00:00,  6.50it/s]
Training task 6, Epoch 5/20: 100%|██████████| 11/11 [00:01<00:00,  6.47it/s]
Training task 6, Epoch 6/20: 100%|██████████| 11/11 [00:01<00:00,  6.60it/s]
Training task 6, Epoch 7/20: 100%|██████████| 11/11 [00:02<00:00,  5.29it/s]
Training task 6, Epoch 8/20: 100%|██████████| 11/11 [00:02<00:00,  4.44it/s]
Training task 6, Epoch 9/20: 100%|██████████| 11/11 [00:01<00:00,  6.32it/s]
Training task 6, Epoch 10/20: 100%|██████████| 11/11 [00:01<00:00,  6.43it/s]
Training task 6, Epoch 11/20: 100%|██████████| 11/11 [00:01<00:00,  6.39it/s]
Training task 6, Epoch 12/20: 100%|██████████| 11/11 [00:01<00:00,  6.42it/s]
Training task 6, Epoch 13/20: 100%|██████████| 11/11 [00:01<00:00,  6.55i

Evaluation Accucary on Task 6 : 0.3829225352112676





In [None]:
# Evaluate each task's own network on that task's test set.
for task_number, testloader in enumerate(testloaders, start = 1):
  indNets.eval(task_number, testloader)

Evaluating Task 1: 100%|██████████| 40/40 [00:07<00:00,  5.26it/s]


Evaluation Accucary on Task 1 : 0.4856687898089172


Evaluating Task 2: 100%|██████████| 39/39 [00:07<00:00,  5.29it/s]


Evaluation Accucary on Task 2 : 0.42251223491027734


Evaluating Task 3: 100%|██████████| 77/77 [00:14<00:00,  5.41it/s]


Evaluation Accucary on Task 3 : 0.4180194805194805


Evaluating Task 4: 100%|██████████| 57/57 [00:10<00:00,  5.33it/s]


Evaluation Accucary on Task 4 : 0.7240618101545254


Evaluating Task 5: 100%|██████████| 103/103 [00:18<00:00,  5.60it/s]


Evaluation Accucary on Task 5 : 0.3812729498164015


Evaluating Task 6: 100%|██████████| 71/71 [00:12<00:00,  5.48it/s]

Evaluation Accucary on Task 6 : 0.3829225352112676





# Different Datasets Example

In [None]:
# One task per dataset, in loader order: MNIST, CIFAR-100, Flowers-102.
num_outputs = [10, 100, 102]
lr = 1e-3
epochs = 10
trainloaders, testloaders = get_diff_datasets()

## Finetune

In [None]:
# The shared head is sized by the last (and largest) entry of num_outputs.
head_size = num_outputs[-1]
finetune = Finetune(head_size, epochs, lr)
for task_loaders in zip(trainloaders, testloaders):
  finetune.train(*task_loaders)

In [None]:
# Re-evaluate every task with the final model, after all tasks were trained.
for task_number, testloader in enumerate(testloaders, start = 1):
  finetune.eval(task_number, testloader)

## Individual Networks

In [None]:
# Per-task learning rates (MNIST, CIFAR-100, Flowers-102). Kept under their own
# name: looping `for ..., lr in zip(..., lr)` rebound `lr` to a float, which
# made re-running this cell fail ("'float' object is not iterable") and leaked
# 1e-4 into `lr` for any later re-run of the Finetune cell.
task_lrs = [1e-3, 1e-3, 1e-4]
indNets = IndivindualNets(3, num_outputs, epochs)
for trainloader, testloader, task_lr in zip(trainloaders, testloaders, task_lrs):
  indNets.train(trainloader, testloader, task_lr)

In [None]:
# Evaluate each task's own network on that task's test set.
for task_number, testloader in enumerate(testloaders, start = 1):
  indNets.eval(task_number, testloader)