# Imports


In [None]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms
import torchvision.models as models

from tqdm import tqdm
from typing import List

In [None]:
# Dataset roots and checkpoint file. Each one can be overridden with an
# environment variable so the notebook runs on machines with another layout;
# without the variable the original default is used.
mnist_path = os.environ.get("MNIST_PATH", '/datasets/mnist')
cifar_path = os.environ.get("CIFAR_PATH", '/datasets/cifar')
flower_path = os.environ.get("FLOWER_PATH", '/datasets/flowers102')
packnet_path = os.environ.get("PACKNET_PATH", 'packnet_weights.pth')

In [None]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Datasets

## CIFAR-100

In [None]:
def get_cifar100(batch_size = 64, num_tasks = 5, num_classes_per_task = 20):
  """Split CIFAR-100 into class-incremental tasks and build one loader pair per task.

  Task ``i`` holds the samples whose label lies in
  ``[i * num_classes_per_task, (i + 1) * num_classes_per_task)``. The defaults
  (5 tasks x 20 classes) cover all 100 classes.

  Args:
    batch_size: batch size of every train and test loader.
    num_tasks: number of tasks to create.
    num_classes_per_task: number of consecutive class ids per task.

  Returns:
    ``(train_task_loaders, test_task_loaders)``: two lists with one
    ``DataLoader`` per task. Train loaders shuffle; test loaders do not.
  """
  # Per-channel mean / std used for normalization.
  mean = (0.5071, 0.4867, 0.4408)
  std = (0.2675, 0.2565, 0.2761)

  train_transform = transforms.Compose([
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      transforms.Normalize(mean, std)])

  test_transform = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize(mean, std)])

  train_dataset = torchvision.datasets.CIFAR100(cifar_path, train=True, download=True, transform=train_transform)
  test_dataset = torchvision.datasets.CIFAR100(cifar_path, train=False, download=True, transform=test_transform)

  # Read labels straight from ``targets``: indexing ``dataset[j]`` would decode
  # and transform every image, once per task, just to look at its label.
  train_targets = list(train_dataset.targets)
  test_targets = list(test_dataset.targets)

  train_task_loaders = []
  test_task_loaders = []

  for i in range(num_tasks):
    lo, hi = i * num_classes_per_task, (i + 1) * num_classes_per_task
    train_idx = [j for j, y in enumerate(train_targets) if lo <= y < hi]
    test_idx = [j for j, y in enumerate(test_targets) if lo <= y < hi]

    train_task_dataset = torch.utils.data.Subset(train_dataset, train_idx)
    test_task_dataset = torch.utils.data.Subset(test_dataset, test_idx)

    train_task_loaders.append(DataLoader(train_task_dataset, batch_size = batch_size, shuffle = True, num_workers = 2))
    test_task_loaders.append(DataLoader(test_task_dataset, batch_size = batch_size, shuffle = False))

  return train_task_loaders, test_task_loaders

## Permuted MNIST

In [None]:
class PermuteMNISTTask:
    """Transform that reorders the pixels of a 32x32 image with a fixed permutation.

    Every leading slice (e.g. each channel) is permuted with the same index
    tensor. The output keeps the 32x32 spatial layout.
    """

    def __init__(self, permutation):
        # Index tensor of length 32 * 32: output pixel k is taken from source
        # pixel permutation[k] of the flattened image.
        self.permutation = permutation

    def __call__(self, x):
        side = 32
        flat = x.view(-1, side * side)
        shuffled = flat[:, self.permutation]
        return shuffled.view(-1, side, side)

def get_permuted_mnist(batch_size = 64, num_tasks = 5, seed = None):
  """Build Permuted-MNIST tasks: one fixed random pixel permutation per task.

  Images are expanded to 3 channels, resized to 32x32, converted to tensors and
  then pixel-permuted with the task's permutation.

  Args:
    batch_size: batch size of every train and test loader.
    num_tasks: number of permuted tasks to create.
    seed: optional integer seed for the permutations. ``None`` (the default)
      draws them from the global torch RNG, as before. Pass a fixed value to get
      the same permutations in a fresh session, e.g. when evaluating a saved
      checkpoint.

  Returns:
    ``(train_loaders, test_loaders)``: two lists with one ``DataLoader`` per task.
  """
  generator = None
  if seed is not None:
    generator = torch.Generator().manual_seed(seed)

  permutations = [torch.randperm(32 * 32, generator=generator) for _ in range(num_tasks)]

  tasks = [
      transforms.Compose([
          transforms.Grayscale(num_output_channels=3),
          torchvision.transforms.Resize(32),
          transforms.ToTensor(),
          PermuteMNISTTask(permutation),
      ])
      for permutation in permutations
  ]

  train_loaders = []
  test_loaders = []

  for task in tasks:
    train_dataset = torchvision.datasets.MNIST(mnist_path, train=True, download=True, transform = task)
    test_dataset = torchvision.datasets.MNIST(mnist_path, train=False, download=True, transform = task)

    train_loaders.append(DataLoader(train_dataset, batch_size = batch_size, shuffle = True, num_workers = 2))
    test_loaders.append(DataLoader(test_dataset, batch_size = batch_size, shuffle = False))

  return train_loaders, test_loaders

## Flowers-102

In [None]:
def get_flowers102(batch_size = 16, num_tasks = 6, num_classes_per_task = 17):
  """Split Flowers-102 into class-incremental tasks and build one loader pair per task.

  Task ``i`` holds the samples whose label lies in
  ``[i * num_classes_per_task, (i + 1) * num_classes_per_task)``. The defaults
  are 6 tasks x 17 classes = 102 classes.

  Args:
    batch_size: batch size of every train and test loader.
    num_tasks: number of tasks to create.
    num_classes_per_task: number of consecutive class ids per task.

  Returns:
    ``(train_task_loaders, test_task_loaders)``: two lists with one
    ``DataLoader`` per task.
  """
  # NOTE(review): these are the same channel statistics used for CIFAR-100 in
  # this file; confirm they are intended for Flowers-102.
  mean = (0.5071, 0.4867, 0.4408)
  std = (0.2675, 0.2565, 0.2761)

  train_transform = transforms.Compose([
      transforms.Resize([256, 256]),
      transforms.CenterCrop(224),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      transforms.Normalize(mean, std)])

  test_transform = transforms.Compose([
      transforms.Resize([256, 256]),
      transforms.CenterCrop(224),
      transforms.ToTensor(),
      transforms.Normalize(mean, std)])

  train_dataset = torchvision.datasets.Flowers102(flower_path, split="train", download=True, transform=train_transform)
  test_dataset = torchvision.datasets.Flowers102(flower_path, split="test", download=True, transform=test_transform)

  def read_labels(dataset):
    # torchvision's Flowers102 keeps its labels in the (private) ``_labels``
    # list. Reading it avoids decoding, resizing and transforming every image
    # once per task. Fall back to indexing the dataset if it is missing.
    labels = getattr(dataset, "_labels", None)
    if labels is None:
      labels = [dataset[j][1] for j in range(len(dataset))]
    return list(labels)

  train_labels = read_labels(train_dataset)
  test_labels = read_labels(test_dataset)

  train_task_loaders = []
  test_task_loaders = []

  for i in range(num_tasks):
    lo, hi = i * num_classes_per_task, (i + 1) * num_classes_per_task
    train_idx = [j for j, y in enumerate(train_labels) if lo <= y < hi]
    test_idx = [j for j, y in enumerate(test_labels) if lo <= y < hi]

    train_task_dataset = torch.utils.data.Subset(train_dataset, train_idx)
    test_task_dataset = torch.utils.data.Subset(test_dataset, test_idx)

    train_task_loaders.append(DataLoader(train_task_dataset, batch_size = batch_size, shuffle = True, num_workers = 2))
    test_task_loaders.append(DataLoader(test_task_dataset, batch_size = batch_size, shuffle = False))

  return train_task_loaders, test_task_loaders

## Diff Datasets

In [None]:
def get_diff_datasets(batch_size = 32):
  """Build a three-task experiment: MNIST, then CIFAR-100, then Flowers-102.

  Every image is resized to 256x256, center-cropped to 224 and normalized with
  one shared set of channel statistics. Training images are also randomly
  flipped. MNIST is first converted to 3 channels so all tasks feed the same
  input format to the shared backbone.

  Args:
    batch_size: batch size of every train and test loader.

  Returns:
    ``(trainloaders, testloaders)``: lists of three loaders each, ordered
    MNIST, CIFAR-100, Flowers-102.
  """
  normalize = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
  resize_crop = [transforms.Resize([256, 256]), transforms.CenterCrop(224)]
  to_tensor = [transforms.ToTensor(), normalize]

  train_transform = transforms.Compose(resize_crop + [transforms.RandomHorizontalFlip()] + to_tensor)
  test_transform = transforms.Compose(resize_crop + to_tensor)

  to_rgb = transforms.Grayscale(num_output_channels=3)

  # (train split, test split) per task, in task order.
  splits = [
      (torchvision.datasets.MNIST(mnist_path, train=True, download=True,
                                  transform = transforms.Compose([to_rgb, train_transform])),
       torchvision.datasets.MNIST(mnist_path, train=False, download=True,
                                  transform = transforms.Compose([to_rgb, test_transform]))),
      (torchvision.datasets.CIFAR100(cifar_path, train=True, download=True, transform=train_transform),
       torchvision.datasets.CIFAR100(cifar_path, train=False, download=True, transform=test_transform)),
      (torchvision.datasets.Flowers102(flower_path, split="train", download=True, transform=train_transform),
       torchvision.datasets.Flowers102(flower_path, split="test", download=True, transform=test_transform)),
  ]

  trainloaders, testloaders = [], []
  for train_set, test_set in splits:
    trainloaders.append(DataLoader(train_set, batch_size = batch_size, shuffle = True, num_workers = 2))
    testloaders.append(DataLoader(test_set, batch_size = batch_size, shuffle = False))

  return trainloaders, testloaders


# Packnet

## Network

In [None]:
class PacknetResNet(nn.Module):
  """ResNet-50 backbone shared by all tasks, plus one linear head per task.

  Every child of an (untrained) torchvision ResNet-50 except ``fc`` is placed
  in ``self.shared``. Per-task heads live in ``self.classifiers``.
  ``add_dataset`` registers a head; ``set_dataset`` selects the head that
  ``forward`` uses.
  """

  # Width of the pooled backbone feature vector (ResNet-50's fc input size).
  feature_dim = 2048

  # BatchNorm layer types that train_nobn keeps in eval mode.
  _BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)

  def __init__(self, make_model = True):
    # ``make_model`` is unused; it is kept so existing calls keep working.
    super(PacknetResNet, self).__init__()

    resnet = models.resnet50(weights = None)

    self.datasets = []
    self.classifiers = nn.ModuleList()
    self.shared = nn.Sequential()

    for name, module in resnet.named_children():
      if name != 'fc':
        self.shared.add_module(name, module)

    # set_dataset() must be called before forward(); until then there is no head.
    self.classifier = None

  def train_nobn(self, mode = True):
    """Like ``train(mode)``, but keep the backbone's BatchNorm layers in eval mode.

    BatchNorm layers in eval mode use, and do not update, their running
    statistics.

    Returns:
      self, mirroring ``nn.Module.train``.
    """
    super(PacknetResNet, self).train(mode)

    for module in self.shared.modules():
      if isinstance(module, self._BN_TYPES):
        module.eval()
    return self

  def add_dataset(self, dataset, num_outputs):
    """Register a new task ``dataset`` with a fresh ``num_outputs``-way linear head.

    Does nothing if the task name is already registered. The head is created on
    the same device as the backbone.
    """
    if dataset not in self.datasets:
      print("Adding new dataset : {} with number of outputs : {}".format(dataset, num_outputs))
      self.datasets.append(dataset)
      backbone_device = next(self.shared.parameters()).device
      self.classifiers.append(nn.Linear(self.feature_dim, num_outputs).to(backbone_device))

  def set_dataset(self, dataset):
    """Make the head registered for ``dataset`` the active classifier.

    Raises:
      ValueError: if ``dataset`` was never registered with ``add_dataset``.
    """
    if dataset not in self.datasets:
      raise ValueError("Unknown dataset {!r}; call add_dataset() first".format(dataset))
    self.classifier = self.classifiers[self.datasets.index(dataset)]
    print("Setting dataset : {} with classifier : {}".format(dataset, self.classifier))

  def forward(self, x):
    """Run the shared backbone, flatten the pooled features and apply the active head."""
    if self.classifier is None:
      raise RuntimeError("No active classifier: call add_dataset() and set_dataset() before forward()")
    x = self.shared(x)
    x = torch.flatten(x, 1)
    return self.classifier(x)

## Pruner

In [None]:
class SparsePruner(object):
  """Magnitude-prunes the shared backbone and manages per-weight ownership masks.

  ``previous_masks`` maps the index of each Conv2d/Linear module in
  ``model.shared.modules()`` to a tensor shaped like that module's weight. Each
  entry holds the index of the task that owns the weight, or 0 for a weight
  that is pruned / free.

  Args:
    model: network exposing its backbone as ``model.shared``.
    prune_perc: fraction (0-1) of the current task's weights to prune per layer.
    previous_masks: ownership masks as described above.
    train_bias: when False, Conv2d/Linear bias gradients are zeroed.
    train_bn: when False, BatchNorm weight/bias gradients are zeroed.
  """

  # BatchNorm types whose affine parameters are frozen when train_bn is False.
  _BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)

  def __init__(self, model, prune_perc, previous_masks, train_bias = False, train_bn = False):
    self.model = model
    self.prune_perc = prune_perc
    self.train_bias = train_bias
    self.train_bn = train_bn

    self.current_masks = None
    self.previous_masks = previous_masks
    # Current task index = largest owner id in the first layer's mask (0-dim tensor).
    valid_key = list(previous_masks.keys())[0]
    self.current_dataset_idx = previous_masks[valid_key].max()

  def _masked_layers(self):
    """Yield ``(module_idx, module)`` for every Conv2d/Linear layer of ``model.shared``."""
    for module_idx, module in enumerate(self.model.shared.modules()):
      if isinstance(module, (nn.Conv2d, nn.Linear)):
        yield module_idx, module

  def pruning_mask(self, weights, previous_mask, layer_idx):
    """Mark the smallest-magnitude weights owned by the current task as pruned (0).

    Among the weights owned by the current task, those whose absolute value is
    <= the k-th smallest absolute value (k = round(prune_perc * count)) get mask
    value 0.

    Args:
      weights: the layer's weight tensor.
      previous_mask: the layer's ownership mask. It is modified in place when it
        is already on ``device``.
      layer_idx: module index, used for logging only.

    Returns:
      The updated mask, on ``device``.
    """
    previous_mask = previous_mask.to(device)
    owned = previous_mask.eq(self.current_dataset_idx.to(device))
    prunable_weights = weights[owned]
    cutoff_rank = round(self.prune_perc * prunable_weights.numel())

    if cutoff_rank < 1:
      # Nothing to prune (no owned weights, or prune_perc too small for this
      # layer); kthvalue() would fail for k == 0.
      print('Layer #%d, nothing to prune (%d prunable weights)' % (layer_idx, prunable_weights.numel()))
      return previous_mask

    cutoff_value = prunable_weights.abs().view(-1).cpu().kthvalue(cutoff_rank)[0].item()

    # Only weights owned by the current task may be removed.
    remove_mask = weights.abs().le(cutoff_value) & owned
    previous_mask[remove_mask] = 0
    mask = previous_mask
    print('Layer #%d, pruned %d/%d (%.2f%%) (Total in layer: %d)' %
          (layer_idx, mask.eq(0).sum(), prunable_weights.numel(),
           100 * mask.eq(0).sum() / prunable_weights.numel(), weights.numel()))
    return mask

  def prune(self):
    """Compute a pruning mask for every layer and zero the pruned weights.

    The results are stored in ``self.current_masks``.
    """
    print('Pruning for dataset idx: %d' % (self.current_dataset_idx))
    assert not self.current_masks, 'Current mask is not empty which means that pruning is already done for this task '
    self.current_masks = {}

    print('Pruning each layer by removing %.2f%% of values' % (100 * self.prune_perc))

    for module_idx, module in self._masked_layers():
      mask = self.pruning_mask(module.weight.data, self.previous_masks[module_idx], module_idx)
      self.current_masks[module_idx] = mask.to(device)
      # Zero the pruned weights immediately.
      module.weight.data[self.current_masks[module_idx].eq(0)] = 0.0

  def make_grads_zero(self):
    """Zero the gradients of parameters the current task must not change.

    Conv2d/Linear weights not owned by the current task always get zero
    gradient. Their biases are zeroed unless ``train_bias``. BatchNorm weight
    and bias gradients are zeroed unless ``train_bn``. Parameters without a
    gradient are skipped.
    """
    assert self.current_masks
    current_idx = self.current_dataset_idx.to(device)
    for module_idx, module in enumerate(self.model.shared.modules()):
      if isinstance(module, (nn.Conv2d, nn.Linear)):
        if module.weight.grad is None:
          continue
        layer_mask = self.current_masks[module_idx]
        module.weight.grad.data[layer_mask.ne(current_idx)] = 0
        if not self.train_bias and module.bias is not None and module.bias.grad is not None:
          module.bias.grad.fill_(0)
      elif isinstance(module, self._BN_TYPES) and not self.train_bn:
        for param in (module.weight, module.bias):
          if param is not None and param.grad is not None:
            param.grad.fill_(0)

  def make_pruned_zero(self):
    """Set every weight whose current mask entry is 0 to exactly zero."""
    assert self.current_masks

    for module_idx, module in self._masked_layers():
      module.weight.data[self.current_masks[module_idx].eq(0)] = 0.0

  def apply_mask(self, dataset_idx):
    """Prepare the backbone for evaluating task ``dataset_idx``.

    Zeroes pruned weights (mask 0) and weights owned by tasks after
    ``dataset_idx``. This overwrites the model's weights in place, so the model
    cannot afterwards be evaluated on a later task without restoring them.
    """
    for module_idx, module in self._masked_layers():
      weight = module.weight.data
      mask = self.previous_masks[module_idx].to(device)
      weight[mask.eq(0)] = 0.0
      weight[mask.gt(dataset_idx)] = 0.0

  def restore_biases(self, biases):
    """Copy the given ``module_idx -> bias`` tensors into the layers' biases."""
    for module_idx, module in self._masked_layers():
      if module.bias is not None:
        module.bias.data.copy_(biases[module_idx])

  def get_biases(self):
    """Return ``module_idx -> cloned bias`` for every layer that has a bias."""
    return {
        module_idx: module.bias.data.clone()
        for module_idx, module in self._masked_layers()
        if module.bias is not None
    }

  def make_finetuning_mask(self):
    """Start a new task: give every free (0) weight to it.

    Increments ``current_dataset_idx`` and, in every layer mask, replaces 0 with
    the new index. The masks are changed in place and also become
    ``current_masks``.
    """
    assert self.previous_masks
    self.current_dataset_idx += 1
    print("Current Dataset Idx:", self.current_dataset_idx)

    for module_idx, _ in self._masked_layers():
      mask = self.previous_masks[module_idx]
      mask[mask.eq(0)] = self.current_dataset_idx

    self.current_masks = self.previous_masks

## Main Class

In [None]:
class Packnet(object):
  """Runs the finetune phase or the prune phase of PackNet for one task.

  Constructor args:
    args: hyperparameter dict. Keys read here: 'prune_perc_per_layer',
      'train_biases', 'train_bn', 'num_outputs', 'lr', plus 'finetune_epochs'
      (used by train_task) or 'post_prune_epochs' (used by prune).
    model: network with a ``shared`` backbone and an active classifier head.
    previous_masks: per-layer ownership masks handed to ``SparsePruner``.
    trainloader, testloader: data of the current task.
  """

  def __init__(self, args, model, previous_masks, trainloader, testloader):
    self.args = args
    self.model = model

    self.train_data_loader = trainloader
    self.test_data_loader = testloader

    self.criterion = nn.CrossEntropyLoss()

    self.pruner = SparsePruner(
        self.model, self.args['prune_perc_per_layer'], previous_masks,
        self.args['train_biases'], self.args['train_bn'])

  def train(self, epochs, optimizer):
    """Train for ``epochs`` epochs, updating only the current task's weights, then evaluate.

    Labels are taken modulo ``args['num_outputs']``. For example, labels 20..39
    become 0..19 when num_outputs is 20.
    """
    self.model.to(device)

    # BatchNorm running statistics are learned on task 1 only; later tasks
    # keep BN layers in eval mode via train_nobn().
    if self.pruner.current_dataset_idx == 1:
      self.model.train()
    else:
      self.model.train_nobn()

    num_outputs = self.args['num_outputs']
    for epoch in range(epochs):
      for inputs, label in tqdm(self.train_data_loader, desc = 'Epoch {}/{}'.format(epoch + 1, epochs)):
        inputs = inputs.to(device)
        label = label.to(device) % num_outputs

        self.model.zero_grad()

        # Forward-backward.
        output = self.model(inputs)
        self.criterion(output, label).backward()

        # Zero the gradients of weights fixed by earlier tasks (and of frozen biases / BN).
        self.pruner.make_grads_zero()

        optimizer.step()

        # Keep pruned weights at exactly zero after the step.
        self.pruner.make_pruned_zero()

    self.eval(self.pruner.current_dataset_idx)
    print('Finished Training...')
    print('-' * 16)

  def eval(self, dataset_idx, biases = None):
    """Evaluate on the test loader using only the weights available to task ``dataset_idx``.

    ``apply_mask`` zeroes, in place, the weights of tasks after ``dataset_idx``.
    Restore the model (e.g. reload the checkpoint) before evaluating a later task.

    Args:
      dataset_idx: task index whose view of the backbone is evaluated.
      biases: optional ``module_idx -> bias`` tensors to restore first.

    Returns:
      Accuracy in [0, 1], or NaN when the test loader yields no samples.
    """
    self.pruner.apply_mask(dataset_idx)
    if biases is not None:
      self.pruner.restore_biases(biases)

    self.model.eval()
    num_outputs = self.args['num_outputs']
    eval_total = 0
    eval_correct = 0
    with torch.no_grad():
      for inputs, label in tqdm(self.test_data_loader, desc = 'Evaluation on task {}'.format(dataset_idx)):
        inputs = inputs.to(device)
        label = label.to(device) % num_outputs

        output = self.model(inputs)

        predictions = output.argmax(dim = 1)
        eval_total += label.shape[0]
        eval_correct += int((predictions == label).sum())

    accuracy = eval_correct / eval_total if eval_total else float('nan')
    print("Validation Accuracy: {:.4f}".format(accuracy))
    return accuracy

  def prune(self):
    """Prune the current task's weights, optionally retrain, and report sparsity."""
    print('Pre-prune evaluation:')
    self.eval(self.pruner.current_dataset_idx)

    self.pruner.prune()
    self.check(True)

    print("\nPost-prune evaluation:")
    self.eval(self.pruner.current_dataset_idx)

    if self.args['post_prune_epochs']:
      print('Doing post prune training...')

      optimizer = optim.Adam(self.model.parameters(), lr = self.args['lr'])
      self.train(self.args['post_prune_epochs'], optimizer)

    print('-' * 16)
    print('Pruning summary:')
    self.check(True)

  def check(self, verbose = True):
    """Print the fraction of zero-valued weights in every Conv2d/Linear layer."""
    print('Checking Network...')
    for layer_idx, module in enumerate(self.model.shared.modules()):
      if isinstance(module, (nn.Conv2d, nn.Linear)):
        weight = module.weight.data
        num_params = weight.numel()
        num_zero = weight.view(-1).eq(0).sum()
        if verbose:
          print('Layer #%d: Pruned %d/%d (%.2f%%)' %
                (layer_idx, num_zero, num_params, 100 * num_zero / num_params))

  def save_model(self, savename):
    """Save args, the current masks and the whole model object to ``savename``.

    The model is pickled as an ``nn.Module`` rather than a state_dict. Loading
    it therefore needs ``torch.load(..., weights_only=False)`` on recent
    PyTorch releases, and the class definitions must be importable.
    """
    ckpt = {
        'args': self.args,
        'previous_masks': self.pruner.current_masks,
        'model': self.model,
    }
    torch.save(ckpt, savename)

  def train_task(self):
    """Hand all free weights to a new task and finetune for ``args['finetune_epochs']``."""
    self.pruner.make_finetuning_mask()
    optimizer = optim.Adam(self.model.parameters(), lr = self.args['lr'])

    self.train(self.args['finetune_epochs'], optimizer)

# CIFAR-100 Example

In [None]:
# CIFAR-100 split into per-task loaders; each task head has 20 outputs.
trainloaders, testloaders = get_cifar100()
num_outputs, lr, epochs = 20, 1e-3, 10

In [None]:
model = PacknetResNet()

# One uint8 ownership mask per Conv2d/Linear layer of the shared backbone,
# keyed by the layer's index in model.shared.modules(). All zeros = every
# weight is free, so the first task may use all of them. Allocated directly on
# `device` instead of via the legacy ByteTensor constructor plus a copy.
previous_masks = {
    module_idx: torch.zeros(module.weight.shape, dtype=torch.uint8, device=device)
    for module_idx, module in enumerate(model.shared.modules())
    if isinstance(module, (nn.Conv2d, nn.Linear))
}

In [None]:
# Finetune-phase settings: the task name is "CIFAR100_" + task index.
train_args = dict(num_outputs=num_outputs, lr=lr, finetune_epochs=7, dataset="CIFAR100_",
                  prune_perc_per_layer=0.7, train_biases=False, train_bn=False)

# Prune-phase settings: retrain after pruning with a 10x smaller learning rate.
prune_args = dict(num_outputs=num_outputs, lr=lr * 0.1, dataset="CIFAR100_",
                  prune_perc_per_layer=0.7, post_prune_epochs=3, train_biases=False, train_bn=False)

In [None]:
# Learn the tasks one after another: finetune on the free weights, prune,
# then pass the updated masks on to the next task.
task_idx = 1
for trainloader, testloader in zip(trainloaders, testloaders):
  
  #train
  # The pruner reads previous_masks here; train_task() then sets every free (0)
  # mask entry to this task's index, mutating the same dict in place.
  packnet = Packnet(train_args, model, previous_masks, trainloader, testloader)
  model.add_dataset(train_args['dataset'] + str(task_idx), train_args['num_outputs'])
  model.set_dataset(train_args['dataset'] + str(task_idx))
  # Make sure the newly added head is on `device`.
  model.to(device)
  packnet.train_task()

  #prune
  # Hand the finetuned masks to a second Packnet configured with the prune-phase args.
  previous_masks = packnet.pruner.current_masks
  packnet = Packnet(prune_args, model, previous_masks, trainloader, testloader)
  packnet.prune()

  task_idx += 1

  # After pruning, pruned weights are 0 again and become free for the next task.
  previous_masks = packnet.pruner.current_masks

# Stores the final masks together with the whole model object.
packnet.save_model(packnet_path)

In [None]:
task_idx = 1
for trainloader , testloader in zip(trainloaders, testloaders):
  # Reload for every task: Packnet.eval() -> apply_mask() zeroes the weights of
  # later tasks in place, so the model evaluated on task k can't be reused.
  # weights_only=False because the checkpoint pickles the whole nn.Module, which
  # recent PyTorch refuses to load by default. Only load files you created.
  ckpt = torch.load(packnet_path, map_location=device, weights_only=False)
  model = ckpt['model']
  previous_masks = ckpt['previous_masks']

  model.set_dataset(train_args['dataset'] + str(task_idx))
  model.to(device)

  packnet = Packnet(train_args, model, previous_masks, trainloader, testloader)

  packnet.eval(task_idx)

  task_idx += 1

# Permuted MNIST Example

In [None]:
# Permuted-MNIST: one loader pair per permutation; 10 digit classes per task.
trainloaders, testloaders = get_permuted_mnist()
num_outputs, lr, epochs = 10, 1e-3, 5

In [None]:
model = PacknetResNet()

# One uint8 ownership mask per Conv2d/Linear layer of the shared backbone,
# keyed by the layer's index in model.shared.modules(). All zeros = every
# weight is free, so the first task may use all of them. Allocated directly on
# `device` instead of via the legacy ByteTensor constructor plus a copy.
previous_masks = {
    module_idx: torch.zeros(module.weight.shape, dtype=torch.uint8, device=device)
    for module_idx, module in enumerate(model.shared.modules())
    if isinstance(module, (nn.Conv2d, nn.Linear))
}

In [None]:
# Finetune-phase settings (fixed: `num_ouputs` was a typo and raised NameError).
train_args = {"num_outputs" : num_outputs, "lr" : lr,  "finetune_epochs" : 3, "dataset" : "PMNIST_", 
              "prune_perc_per_layer" : 0.7, "train_biases" : False, "train_bn" : False}

# Prune-phase settings: retrain after pruning with a 10x smaller learning rate.
prune_args = {"num_outputs" : num_outputs, "lr" : lr * 0.1, "dataset" : "PMNIST_", 
              "prune_perc_per_layer" : 0.7, "post_prune_epochs" : 2,  "train_biases" : False, "train_bn" : False}

In [None]:
# Learn the tasks one after another: finetune on the free weights, prune,
# then pass the updated masks on to the next task.
task_idx = 1
for trainloader, testloader in zip(trainloaders, testloaders):
  #train
  # The pruner reads previous_masks here; train_task() then sets every free (0)
  # mask entry to this task's index, mutating the same dict in place.
  packnet = Packnet(train_args, model, previous_masks, trainloader, testloader)
  model.add_dataset(train_args['dataset'] + str(task_idx), train_args['num_outputs'])
  model.set_dataset(train_args['dataset'] + str(task_idx))
  # Make sure the newly added head is on `device`.
  model.to(device)
  packnet.train_task()

  #prune
  # Hand the finetuned masks to a second Packnet configured with the prune-phase args.
  previous_masks = packnet.pruner.current_masks
  packnet = Packnet(prune_args, model, previous_masks, trainloader, testloader)
  packnet.prune()

  task_idx += 1

  # After pruning, pruned weights are 0 again and become free for the next task.
  previous_masks = packnet.pruner.current_masks

# Stores the final masks together with the whole model object.
packnet.save_model(packnet_path)

In [None]:
task_idx = 1
for trainloader , testloader in zip(trainloaders, testloaders):
  # Reload for every task: Packnet.eval() -> apply_mask() zeroes the weights of
  # later tasks in place, so the model evaluated on task k can't be reused.
  # weights_only=False because the checkpoint pickles the whole nn.Module, which
  # recent PyTorch refuses to load by default. Only load files you created.
  ckpt = torch.load(packnet_path, map_location=device, weights_only=False)
  model = ckpt['model']
  previous_masks = ckpt['previous_masks']

  model.set_dataset(train_args['dataset'] + str(task_idx))
  model.to(device)

  packnet = Packnet(train_args, model, previous_masks, trainloader, testloader)

  packnet.eval(task_idx)

  task_idx += 1

# Flowers-102 Example

In [None]:
# Flowers-102 split into per-task loaders; each task head has 17 outputs.
trainloaders, testloaders = get_flowers102()
num_outputs, lr, epochs = 17, 1e-3, 20

In [None]:
model = PacknetResNet()

# One uint8 ownership mask per Conv2d/Linear layer of the shared backbone,
# keyed by the layer's index in model.shared.modules(). All zeros = every
# weight is free, so the first task may use all of them. Allocated directly on
# `device` instead of via the legacy ByteTensor constructor plus a copy.
previous_masks = {
    module_idx: torch.zeros(module.weight.shape, dtype=torch.uint8, device=device)
    for module_idx, module in enumerate(model.shared.modules())
    if isinstance(module, (nn.Conv2d, nn.Linear))
}

In [None]:
# Finetune-phase settings: the task name is "Flowers_" + task index.
train_args = dict(num_outputs=num_outputs, lr=lr, finetune_epochs=15, dataset="Flowers_",
                  prune_perc_per_layer=0.7, train_biases=False, train_bn=False)

# Prune-phase settings: retrain after pruning with a 10x smaller learning rate.
prune_args = dict(num_outputs=num_outputs, lr=lr * 0.1, dataset="Flowers_",
                  prune_perc_per_layer=0.7, post_prune_epochs=5, train_biases=False, train_bn=False)

In [None]:
# Learn the tasks one after another: finetune on the free weights, prune,
# then pass the updated masks on to the next task.
task_idx = 1
for trainloader, testloader in zip(trainloaders, testloaders):
  #train
  # The pruner reads previous_masks here; train_task() then sets every free (0)
  # mask entry to this task's index, mutating the same dict in place.
  packnet = Packnet(train_args, model, previous_masks, trainloader, testloader)
  model.add_dataset(train_args['dataset'] + str(task_idx), train_args['num_outputs'])
  model.set_dataset(train_args['dataset'] + str(task_idx))
  # Make sure the newly added head is on `device`.
  model.to(device)
  packnet.train_task()

  #prune
  # Hand the finetuned masks to a second Packnet configured with the prune-phase args.
  previous_masks = packnet.pruner.current_masks
  packnet = Packnet(prune_args, model, previous_masks, trainloader, testloader)
  packnet.prune()

  task_idx += 1

  # After pruning, pruned weights are 0 again and become free for the next task.
  previous_masks = packnet.pruner.current_masks

# Stores the final masks together with the whole model object.
packnet.save_model(packnet_path)

In [None]:
task_idx = 1
for trainloader , testloader in zip(trainloaders, testloaders):
  # Reload for every task: Packnet.eval() -> apply_mask() zeroes the weights of
  # later tasks in place, so the model evaluated on task k can't be reused.
  # weights_only=False because the checkpoint pickles the whole nn.Module, which
  # recent PyTorch refuses to load by default. Only load files you created.
  ckpt = torch.load(packnet_path, map_location=device, weights_only=False)
  model = ckpt['model']
  previous_masks = ckpt['previous_masks']

  model.set_dataset(train_args['dataset'] + str(task_idx))
  model.to(device)

  packnet = Packnet(train_args, model, previous_masks, trainloader, testloader)
  packnet.eval(task_idx)

  task_idx += 1

# Diff Datasets Example

In [None]:
# One task per dataset (MNIST, CIFAR-100, Flowers-102); head sizes in that order.
trainloaders, testloaders = get_diff_datasets()
num_outputs, lr, epochs = [10, 100, 102], 1e-3, 10

In [None]:
model = PacknetResNet()

# One uint8 ownership mask per Conv2d/Linear layer of the shared backbone,
# keyed by the layer's index in model.shared.modules(). All zeros = every
# weight is free, so the first task may use all of them. Allocated directly on
# `device` instead of via the legacy ByteTensor constructor plus a copy.
previous_masks = {
    module_idx: torch.zeros(module.weight.shape, dtype=torch.uint8, device=device)
    for module_idx, module in enumerate(model.shared.modules())
    if isinstance(module, (nn.Conv2d, nn.Linear))
}

In [None]:
# Task names, in the same order as the loaders and num_outputs.
datasets = ['MNIST', 'CIFAR100', 'FLOWERS102']

# 'num_outputs' and 'dataset' are filled in per task inside the training loop.
train_args = dict(lr=lr, finetune_epochs=7, prune_perc_per_layer=0.7, train_biases=False, train_bn=False)
prune_args = dict(lr=lr * 0.1, post_prune_epochs=3, prune_perc_per_layer=0.7, train_biases=False, train_bn=False)


In [None]:
# Learn the three datasets one after another: finetune on the free weights,
# prune, then pass the updated masks on to the next dataset.
task_idx = 1
for trainloader, testloader in zip(trainloaders, testloaders):
   
  # Per-task head size and task name (task_idx is 1-based).
  train_args['num_outputs'] = num_outputs[task_idx - 1]
  train_args['dataset'] = datasets[task_idx - 1]
  
  #train
  # The pruner reads previous_masks here; train_task() then sets every free (0)
  # mask entry to this task's index, mutating the same dict in place.
  packnet = Packnet(train_args, model, previous_masks, trainloader, testloader)
  model.add_dataset(train_args['dataset'], train_args['num_outputs'])
  model.set_dataset(train_args['dataset'])
  # Make sure the newly added head is on `device`.
  model.to(device)
  packnet.train_task()

  #prune
  prune_args['num_outputs'] = num_outputs[task_idx - 1]
  prune_args['dataset'] = datasets[task_idx - 1]

  # Hand the finetuned masks to a second Packnet configured with the prune-phase args.
  previous_masks = packnet.pruner.current_masks
  packnet = Packnet(prune_args, model, previous_masks, trainloader, testloader)
  packnet.prune()

  task_idx += 1

  # After pruning, pruned weights are 0 again and become free for the next task.
  previous_masks = packnet.pruner.current_masks

# Stores the final masks together with the whole model object.
packnet.save_model(packnet_path)

In [None]:
task_idx = 1
for trainloader , testloader in zip(trainloaders, testloaders):
  # Reload for every task: Packnet.eval() -> apply_mask() zeroes the weights of
  # later tasks in place, so the model evaluated on task k can't be reused.
  # weights_only=False because the checkpoint pickles the whole nn.Module, which
  # recent PyTorch refuses to load by default. Only load files you created.
  ckpt = torch.load(packnet_path, map_location=device, weights_only=False)
  model = ckpt['model']
  previous_masks = ckpt['previous_masks']

  model.set_dataset(datasets[task_idx - 1])
  model.to(device)

  # Packnet.eval() reads args['num_outputs'], so set it for this task.
  train_args['num_outputs'] = num_outputs[task_idx - 1]
  train_args['dataset'] = datasets[task_idx - 1]
  packnet = Packnet(train_args, model, previous_masks, trainloader, testloader)

  packnet.eval(task_idx)

  task_idx += 1