# Imports

In [None]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms
import torchvision.models as models

from tqdm import tqdm

In [None]:
# Root directories passed to the torchvision dataset constructors below
# (the datasets are downloaded there because download=True is used).
mnist_path = '/datasets/mnist'
cifar_path = '/datasets/cifar'
flower_path = '/datasets/flowers102'

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Datasets

## CIFAR-100

In [None]:
def get_cifar100(batch_size = 64):
  """Build the split CIFAR-100 benchmark: 5 tasks of 20 consecutive classes.

  Args:
    batch_size: batch size used by every returned DataLoader.

  Returns:
    (train_task_loaders, test_task_loaders): two lists with one DataLoader
    per task. Task i holds the samples whose label lies in
    [20 * i, 20 * (i + 1)). Labels are NOT remapped here.
  """

  # Per-channel mean / std used to normalize the images
  normalize = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))

  train_transform = transforms.Compose([
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      normalize])

  test_transform = transforms.Compose([
      transforms.ToTensor(),
      normalize])

  # Load the CIFAR-100 dataset
  train_dataset = torchvision.datasets.CIFAR100(cifar_path, train=True, download=True, transform=train_transform)
  test_dataset = torchvision.datasets.CIFAR100(cifar_path, train=False, download=True, transform=test_transform)

  num_classes_per_task = 20
  num_tasks = 5

  def split_by_task(dataset):
    """Return one list of sample indices per task, in dataset order."""
    # Read the labels from `targets` instead of indexing the dataset:
    # dataset[j] decodes and transforms the image, and doing that for every
    # sample once per task cost several full passes over the data.
    buckets = [[] for _ in range(num_tasks)]
    for idx, label in enumerate(dataset.targets):
      task = int(label) // num_classes_per_task
      if 0 <= task < num_tasks:
        buckets[task].append(idx)
    return buckets

  train_task_loaders = []
  test_task_loaders = []

  for train_idx, test_idx in zip(split_by_task(train_dataset), split_by_task(test_dataset)):
    train_task_dataset = torch.utils.data.Subset(train_dataset, train_idx)
    test_task_dataset = torch.utils.data.Subset(test_dataset, test_idx)

    train_task_loaders.append(DataLoader(train_task_dataset, batch_size = batch_size, shuffle = True, num_workers = 2))
    test_task_loaders.append(DataLoader(test_task_dataset, batch_size = batch_size, shuffle = False))

  return train_task_loaders, test_task_loaders

## Permuted MNIST

In [None]:
class PermuteMNISTTask:
    """Transform that reorders the pixels of 32x32 images with a fixed permutation."""

    def __init__(self, permutation):
        # Index used to reorder the 32*32 flattened pixel positions
        self.permutation = permutation

    def __call__(self, x):
        side = 32
        # Flatten each 32x32 plane to one row, shuffle the columns, restore the planes
        flat = x.view(-1, side * side)
        shuffled = flat[:, self.permutation]
        return shuffled.view(-1, side, side)

def get_permuted_mnist(batch_size = 64):
  """Build 5 permuted-MNIST tasks, each with its own random pixel permutation.

  Returns (train_loaders, test_loaders): one DataLoader per task in each list.
  """

  # Draw all permutations up front, one per task
  pixel_orders = [torch.randperm(32 * 32) for _ in range(5)]

  train_loaders, test_loaders = [], []

  for order in pixel_orders:
    # 3-channel, 32x32 tensor whose pixels are shuffled by this task's permutation
    task_transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        torchvision.transforms.Resize(32),
        transforms.ToTensor(),
        PermuteMNISTTask(order)
    ])

    train_set = torchvision.datasets.MNIST(mnist_path, train=True, download=True, transform = task_transform)
    test_set = torchvision.datasets.MNIST(mnist_path, train=False, download=True, transform = task_transform)

    train_loaders.append(DataLoader(train_set, batch_size = batch_size, shuffle = True, num_workers = 2))
    test_loaders.append(DataLoader(test_set, batch_size = batch_size, shuffle = False))

  return train_loaders, test_loaders

## Flowers 102

In [None]:
def get_flowers102(batch_size = 16):
  """Build the split Flowers-102 benchmark: 6 tasks of 17 consecutive classes.

  Args:
    batch_size: batch size used by every returned DataLoader.

  Returns:
    (train_task_loaders, test_task_loaders): two lists with one DataLoader
    per task. Task i holds the samples whose label lies in
    [17 * i, 17 * (i + 1)). Labels are NOT remapped here.
  """
  # NOTE(review): these are the same normalization statistics used for
  # CIFAR-100 elsewhere in this file; confirm they are intended for Flowers-102.
  normalize = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))

  train_transform = transforms.Compose([
      transforms.Resize([256, 256]),
      transforms.CenterCrop(224),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      normalize])

  test_transform = transforms.Compose([
      transforms.Resize([256, 256]),
      transforms.CenterCrop(224),
      transforms.ToTensor(),
      normalize])

  train_dataset = torchvision.datasets.Flowers102(flower_path, split="train", download=True, transform=train_transform)
  test_dataset = torchvision.datasets.Flowers102(flower_path, split="test", download=True, transform=test_transform)

  num_classes_per_task = 17
  num_tasks = 6

  def labels_of(dataset):
    """Return the label of every sample, in dataset order."""
    # NOTE(review): `_labels` is a private torchvision attribute; it is used
    # only as a shortcut that avoids opening every image. If it is missing
    # we fall back to indexing the dataset, but only once (the original
    # code did a full pass over the images for every task).
    labels = getattr(dataset, "_labels", None)
    if labels is None:
      labels = [dataset[j][1] for j in range(len(dataset))]
    return labels

  def split_by_task(dataset):
    """Return one list of sample indices per task, in dataset order."""
    buckets = [[] for _ in range(num_tasks)]
    for idx, label in enumerate(labels_of(dataset)):
      task = int(label) // num_classes_per_task
      if 0 <= task < num_tasks:
        buckets[task].append(idx)
    return buckets

  train_task_loaders = []
  test_task_loaders = []

  # Create tasks
  for train_idx, test_idx in zip(split_by_task(train_dataset), split_by_task(test_dataset)):
    train_task_dataset = torch.utils.data.Subset(train_dataset, train_idx)
    test_task_dataset = torch.utils.data.Subset(test_dataset, test_idx)

    train_task_loaders.append(DataLoader(train_task_dataset, batch_size = batch_size, shuffle = True, num_workers = 2))
    test_task_loaders.append(DataLoader(test_task_dataset, batch_size = batch_size, shuffle = False))

  return train_task_loaders, test_task_loaders

## Diff Datasets

In [None]:
def get_diff_datasets(batch_size = 32):
  """ Build a three-task experiment: MNIST, then CIFAR100, then Flowers102.

  Returns (trainloaders, testloaders), each holding one DataLoader per dataset
  in that order.
  """

  trainloaders, testloaders = [], []

  def add_task(train_set, test_set):
    # Shuffled, 2-worker loader for training; ordered loader for testing
    trainloaders.append(DataLoader(train_set, batch_size = batch_size, shuffle = True, num_workers = 2))
    testloaders.append(DataLoader(test_set, batch_size = batch_size, shuffle = False))

  # Shared pipeline pieces: resize + crop to 224x224, then tensor + normalization
  resize_crop = [transforms.Resize([256, 256]), transforms.CenterCrop(224)]
  to_normalized = [
      transforms.ToTensor(),
      transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
  ]

  # CIFAR100 and Flowers102 transforms (only training adds the random flip)
  train_transform = transforms.Compose(resize_crop + [transforms.RandomHorizontalFlip()] + to_normalized)
  test_transform = transforms.Compose(resize_crop + to_normalized)

  # MNIST is first expanded to 3 channels, then goes through the same pipeline
  mnist_train_tf = transforms.Compose([transforms.Grayscale(num_output_channels=3), train_transform])
  mnist_test_tf = transforms.Compose([transforms.Grayscale(num_output_channels=3), test_transform])

  # Task 0: MNIST
  add_task(
      torchvision.datasets.MNIST(mnist_path, train=True, download=True, transform = mnist_train_tf),
      torchvision.datasets.MNIST(mnist_path, train=False, download=True, transform = mnist_test_tf),
  )

  # Task 1: CIFAR100
  add_task(
      torchvision.datasets.CIFAR100(cifar_path, train=True, download=True, transform=train_transform),
      torchvision.datasets.CIFAR100(cifar_path, train=False, download=True, transform=test_transform),
  )

  # Task 2: Flowers102
  add_task(
      torchvision.datasets.Flowers102(flower_path, split="train", download=True, transform=train_transform),
      torchvision.datasets.Flowers102(flower_path, split="test", download=True, transform=test_transform),
  )

  return trainloaders, testloaders

# PNN

## Blocks

In [None]:
class ProgBlock(nn.Module):
  """ Interface of the blocks that make up a progressive-network column """

  def runBlock(self, x):
    """ Run the block's own layer(s) on x; subclasses must override """
    raise NotImplementedError()

  def runLateral(self, i, x):
    """ Run the i-th lateral connection on x; subclasses must override """
    raise NotImplementedError()

  def runActivation(self, x):
    """ Apply the block's activation to x; subclasses must override """
    raise NotImplementedError()

  def isLateralized(self):
    """ Blocks accept lateral input unless a subclass says otherwise """
    return True

class ProgInertBlock(ProgBlock):
  """ Base class of the blocks that take no lateral input """

  def isLateralized(self):
    """ Inert blocks report that they are not lateralized """
    return False

In [None]:
class ProgDenseBlock(ProgBlock):
  """ A ProgBlock built around one fully connected layer (nn.Linear), plus one nn.Linear per lateral """

  def __init__(self, inSize, outSize, numLaterals, activation = nn.ReLU()):
    super().__init__()
    self.numLaterals = numLaterals
    self.inSize, self.outSize = inSize, outSize

    # Main layer first, then the laterals (same creation order as before)
    self.module = nn.Linear(inSize, outSize)
    self.laterals = nn.ModuleList()
    for _ in range(numLaterals):
      self.laterals.append(nn.Linear(inSize, outSize))

    # activation=None means "no activation": fall back to the identity
    self.activation = activation if activation is not None else (lambda x : x)

  def runBlock(self, x):
    """ Main path: the linear layer only """
    out = self.module(x)
    return out

  def runLateral(self, i, x):
    """ Lateral path i applied to x """
    return self.laterals[i](x)

  def runActivation(self, x):
    """ Activation applied to x """
    out = self.activation(x)
    return out

In [None]:
class ProgDenseBNBlock(ProgBlock):
  """ A ProgBlock built around one fully connected layer followed by batch normalization """

  def __init__(self, inSize, outSize, numLaterals, activation = nn.ReLU(), bnArgs = dict()):
    super().__init__()
    self.numLaterals = numLaterals
    self.inSize, self.outSize = inSize, outSize

    # Main path: linear layer + its batch norm
    self.module = nn.Linear(inSize, outSize)
    self.moduleBN = nn.BatchNorm1d(outSize, **bnArgs)

    # One linear layer and one batch norm per lateral
    self.laterals = nn.ModuleList()
    for _ in range(numLaterals):
      self.laterals.append(nn.Linear(inSize, outSize))
    self.lateralBNs = nn.ModuleList()
    for _ in range(numLaterals):
      self.lateralBNs.append(nn.BatchNorm1d(outSize, **bnArgs))

    # activation=None means "no activation": fall back to the identity
    self.activation = activation if activation is not None else (lambda x : x)

  def runBlock(self, x):
    """ Main path: linear layer, then batch norm """
    projected = self.module(x)
    return self.moduleBN(projected)

  def runLateral(self, i, x):
    """ Lateral path i: linear layer, then its batch norm """
    projected = self.laterals[i](x)
    return self.lateralBNs[i](projected)

  def runActivation(self, x):
    """ Activation applied to x """
    out = self.activation(x)
    return out

In [None]:
class ProgConv2DBlock(ProgBlock):
  """ A ProgBlock built around one Conv2d layer, plus one Conv2d per lateral """

  def __init__(self, inSize, outSize, kernelSize, numLaterals, activation = nn.ReLU(), layerArgs = dict()):
    super().__init__()
    self.numLaterals = numLaterals
    self.inSize, self.outSize = inSize, outSize
    self.kernSize = kernelSize

    # layerArgs is forwarded to the main convolution and to every lateral
    self.module = nn.Conv2d(inSize, outSize, kernelSize, **layerArgs)
    self.laterals = nn.ModuleList()
    for _ in range(numLaterals):
      self.laterals.append(nn.Conv2d(inSize, outSize, kernelSize, **layerArgs))

    # activation=None means "no activation": fall back to the identity
    self.activation = activation if activation is not None else (lambda x : x)

  def runBlock(self, x):
    """ Main path: the convolution only """
    out = self.module(x)
    return out

  def runLateral(self, i, x):
    """ Lateral path i applied to x """
    return self.laterals[i](x)

  def runActivation(self, x):
    """ Activation applied to x """
    out = self.activation(x)
    return out

In [None]:
class ProgConv2DBNBlock(ProgBlock):
  """ A ProgBlock containing a single Conv2d followed by batch normalization """

  def __init__(self, inSize, outSize, kernelSize, numLaterals, activation = nn.ReLU(), layerArgs = dict(), bnArgs = dict()):
    super().__init__()
    self.numLaterals = numLaterals
    self.inSize, self.outSize = inSize, outSize
    self.kernSize = kernelSize

    # Main path: convolution + its batch norm
    self.module = nn.Conv2d(inSize, outSize, kernelSize, **layerArgs)
    self.moduleBN = nn.BatchNorm2d(outSize, **bnArgs)

    # One convolution and one batch norm per lateral
    self.laterals = nn.ModuleList()
    for _ in range(numLaterals):
      self.laterals.append(nn.Conv2d(inSize, outSize, kernelSize, **layerArgs))
    self.lateralBNs = nn.ModuleList()
    for _ in range(numLaterals):
      self.lateralBNs.append(nn.BatchNorm2d(outSize, **bnArgs))

    # activation=None means "no activation": fall back to the identity
    self.activation = activation if activation is not None else (lambda x: x)

  def runBlock(self, x):
    """ Main path: convolution, then batch norm """
    convolved = self.module(x)
    return self.moduleBN(convolved)

  def runLateral(self, i, x):
    """ Lateral path i: convolution, then its batch norm """
    convolved = self.laterals[i](x)
    return self.lateralBNs[i](convolved)

  def runActivation(self, x):
    """ Activation applied to x """
    out = self.activation(x)
    return out

In [None]:
class ProgResnetBlock(ProgBlock):
  """ A ProgBlock with two Conv2d layers with batch normalization and a skip connection

  Args:
    inSize: input channels.
    outSize: channels produced by the first convolution.
    kernelSize: kernel size of both convolutions and of the laterals.
    numLaterals: number of lateral connections (one per parent column).
    activation: module/callable applied between the two convolutions and by
      runActivation; None means identity.
    bnArgs: extra keyword arguments forwarded to every BatchNorm2d of the block.
    stride: stride of the first convolution and of the laterals.
    expansion: the block outputs outSize * expansion channels.
    downsample: optional module applied to the input before the skip addition.
  """

  def __init__(self, inSize, outSize, kernelSize, numLaterals, activation = nn.ReLU(), bnArgs = dict(), stride = 1, expansion = 1,  downsample = None):

    super().__init__()
    self.numLaterals = numLaterals

    self.expansion = expansion
    self.downsample = downsample

    self.inSize = inSize
    self.outSize = outSize
    self.kernSize = kernelSize

    # Channels leaving the block; laterals must produce the same number
    # because their output is added to the block output.
    blockOut = outSize * self.expansion

    # NOTE: padding is fixed to 1, which keeps the spatial size only for 3x3 kernels.
    self.conv1 = nn.Conv2d(inSize, outSize, kernelSize, stride = stride, padding = 1, bias = False)
    self.bn1 = nn.BatchNorm2d(outSize, **bnArgs)

    self.conv2 = nn.Conv2d(outSize, blockOut, kernelSize, padding = 1, bias = False)
    self.bn2 = nn.BatchNorm2d(blockOut, **bnArgs)

    self.laterals = nn.ModuleList([nn.Conv2d(inSize, blockOut, kernelSize, stride = stride, padding = 1) for _ in range(numLaterals)])
    self.lateralBNs = nn.ModuleList([nn.BatchNorm2d(blockOut, **bnArgs) for _ in range(numLaterals)])

    if activation is None:   self.activation = (lambda x: x)
    else:                    self.activation = activation

  def runBlock(self, x):
    """ conv-bn-activation-conv-bn plus the skip connection (no final activation) """
    identity = x

    out = self.conv1(x)
    out = self.bn1(out)
    out = self.activation(out)

    out = self.conv2(out)
    out = self.bn2(out)

    # Project the input when its shape differs from the block output
    if self.downsample is not None:
      identity = self.downsample(x)

    out += identity

    return out

  def runLateral(self, i, x):
    """ Lateral path i: convolution, then its batch norm """
    lat = self.laterals[i]
    bn = self.lateralBNs[i]
    return bn(lat(x))

  def runActivation(self, x):
    """ Activation applied to x """
    return self.activation(x)

In [None]:
class ProgLambda(ProgInertBlock):
  """ An inert ProgBlock (no laterals) that passes its input through a given callable """

  def __init__(self, module):
    super().__init__()
    # The callable is stored as given and invoked by runBlock
    self.module = module

  def runBlock(self, x):
    """ Apply the wrapped callable to x """
    result = self.module(x)
    return result

  def runActivation(self, x):
    """ No activation: x is returned unchanged """
    return x

class ProgMaxPool(ProgInertBlock):
  """ An inert ProgBlock (no laterals) that applies nn.MaxPool2d  """

  def __init__(self, kernel_size, stride = None,  padding = 0):
    super().__init__()
    self.module = nn.MaxPool2d(kernel_size, stride = stride, padding = padding)

  def runBlock(self, x):
    """ Max-pool x """
    pooled = self.module(x)
    return pooled

  def runActivation(self, x):
    """ No activation: x is returned unchanged """
    return x

class ProgAdaptiveAvgPool(ProgInertBlock):
  """ An inert ProgBlock (no laterals) that applies nn.AdaptiveAvgPool2d with output size (h, w) """

  def __init__(self, h, w):
    super().__init__()
    output_size = (h, w)
    self.module = nn.AdaptiveAvgPool2d(output_size)

  def runBlock(self, x):
    """ Average-pool x to the configured output size """
    pooled = self.module(x)
    return pooled

  def runActivation(self, x):
    """ No activation: x is returned unchanged """
    return x

## Column

In [None]:
class ProgColumn(nn.Module):
  """
  One of the PNN's columns. Outputs are stored to be available for next columns.

  Args:
    colID: identifier of the column.
    blockList: blocks of the column, in forward order (one per row).
    parentCols: previously created columns whose stored outputs feed the
      lateral connections. They are kept in a plain list, so they are not
      registered as submodules of this column.
  """

  def __init__(self, colID, blockList, parentCols = None):
    super().__init__()
    self.colID = colID
    self.isFrozen = False
    # A fresh list per instance: a `[]` default would be shared by all columns
    self.parentCols = [] if parentCols is None else parentCols
    self.blocks = nn.ModuleList(blockList)
    self.numRows = len(blockList)
    # Output of every row from the most recent forward pass
    self.lastOutputList = []

  def freeze(self, unfreeze = False):
    """ Disable (or, with unfreeze=True, re-enable) gradients for every parameter of the column """
    self.isFrozen = not unfreeze
    for param in self.parameters():
      param.requires_grad = bool(unfreeze)

  def forward(self, input):
    """ Run the column on `input` and return the output of its last row.

    For every lateralized row after the first, the lateral projections of the
    parents' previous-row outputs (taken from their last forward pass) are
    added before the activation. Parents must therefore have been run on the
    same input beforehand.
    """
    outputs = []
    x = input
    for row, block in enumerate(self.blocks):
      currOutput = block.runBlock(x)
      if row > 0 and len(self.parentCols) > 0 and block.isLateralized():
        for c, col in enumerate(self.parentCols):
          currOutput += block.runLateral(c, col.lastOutputList[row - 1])
      x = block.runActivation(currOutput)
      outputs.append(x)
    self.lastOutputList = outputs
    return outputs[-1]

In [None]:
class ProgColumnGenerator:
  """ Class that generates automatically new columns with the same architecture """

  def generateColumn(self, parentCols, msg = None):
    """ Build and return a new column whose parents are `parentCols`; subclasses must override """
    raise NotImplementedError

  def generatorColumn(self, parentCols, msg = None):
    """ Misspelled former name of generateColumn, kept so existing callers keep working """
    return self.generateColumn(parentCols, msg)
  

class Resnet18Generator(ProgColumnGenerator):
  """ Generator of columns that follow a ResNet-18 layout """

  def __init__(self, num_classes):
    # Size of the final dense layer of the columns generated from now on
    self.num_classes = num_classes
    # Next column id to hand out
    self.ids = 0

  def generateColumn(self, parentCols, msg = None):
    """ Build a new column with one lateral per parent column in every lateralized block """
    numParents = len(parentCols)

    # Stem: 7x7 stride-2 convolution + batch norm (no laterals), then 2x2 max pool
    blocks = [
      ProgConv2DBNBlock(3, 64, 7, 0, activation = nn.ReLU(), layerArgs = {"stride" : 2, "padding" : 3, "bias" : False}),
      ProgMaxPool(2, 2),
      # First stage keeps 64 channels, so it needs no projection on the skip path
      ProgResnetBlock(64, 64, 3, numParents, activation = nn.ReLU()),
      ProgResnetBlock(64, 64, 3, numParents, activation = nn.ReLU()),
    ]

    # Remaining stages: the first block halves the resolution and changes the
    # channel count, so its skip path goes through a 1x1 stride-2 projection.
    for inChannels, outChannels in ((64, 128), (128, 256), (256, 512)):
      shortcut = nn.Sequential(
        nn.Conv2d(inChannels, outChannels, kernel_size=1, stride= 2, bias=False),
        nn.BatchNorm2d(outChannels),
      )
      # The generator keeps a reference to the most recent projection
      self.downsample = shortcut

      blocks.append(ProgResnetBlock(inChannels, outChannels, 3, numParents, activation = nn.ReLU(), stride = 2, downsample = shortcut))
      blocks.append(ProgResnetBlock(outChannels, outChannels, 3, numParents, activation = nn.ReLU()))

    # Head: global average pool, flatten, dense layer without activation
    blocks.append(ProgAdaptiveAvgPool(1, 1))
    blocks.append(ProgLambda(lambda x : torch.flatten(x, start_dim = 1)))
    blocks.append(ProgDenseBlock(512, self.num_classes, numParents, activation = None))

    return ProgColumn(self.__genID(), blocks, parentCols = parentCols)

  def __genID(self):
    """ Return the next unused column id and advance the counter """
    assigned = self.ids
    self.ids = assigned + 1
    return assigned

## Net 

In [None]:
class Pnn(nn.Module):
  """ The main PNN class: an ordered collection of columns, one per task

  Args:
    colGen: generator used by addColumn to build a column when none is given.
  """
  def __init__(self, colGen = None):
    super().__init__()
    self.columns = nn.ModuleList()
    self.numCols = 0
    # Maps a column id to its position in self.columns
    self.colMap = dict()
    self.colGen = colGen

  def addColumn(self, col = None, msg = None):
    """ Append `col` (or a freshly generated column) and return its id.

    Raises:
      ValueError: if no column is given and the network has no generator.
    """
    if col is None:
      if self.colGen is None:
        raise ValueError("addColumn needs either a column or a column generator (colGen)")
      # Every existing column becomes a parent of the new one
      parents = list(self.columns)
      col = self.colGen.generateColumn(parents, msg)
    col = col.to(device)
    self.columns.append(col)
    self.colMap[col.colID] = self.numCols
    self.numCols += 1
    return col.colID

  def freezeColumn(self, id):
    """ Freeze the column with the given id """
    self.getColumn(id).freeze()

  def freezeAllColumns(self):
    """ Freeze every column """
    for col in self.columns:
      col.freeze()

  def unfreezeColumn(self, id):
    """ Unfreeze the column with the given id """
    self.getColumn(id).freeze(unfreeze = True)

  def unfreezeAllColumns(self):
    """ Unfreeze every column """
    for col in self.columns:
      col.freeze(unfreeze = True)

  def forward(self, id, x):
    """ Return the output of column `id` for input x.

    Columns are run in creation order up to and including the requested one,
    so that each column finds its parents' stored outputs for the laterals.
    """
    colToOutput = self.colMap[id]
    for i, col in enumerate(self.columns):
      y = col(x)
      if i == colToOutput:
        return y

  def getColumn(self, id):
    """ Return the column with the given id """
    return self.columns[self.colMap[id]]

  def begin_task(self):
    """ Add a generated column for a new task and return its id """
    return self.addColumn()

  def end_task(self):
    """ Freeze every column at the end of a task """
    self.freezeAllColumns()

  def trainTask(self, lr,  epochs, trainloader, testloader, num_outputs = 10):
    """ Add a column, train it on `trainloader`, then evaluate it on `testloader`.

    Labels are reduced modulo `num_outputs` before the loss is computed.
    Columns are not frozen here; the caller decides when to freeze.

    Returns:
      The accuracy reported by evalTask for the new column.
    """
    col = self.addColumn()
    self.columns[-1].train() #train only the last column

    # Frozen parameters never receive gradients, so leave them out of the optimizer
    trainable = [param for param in self.parameters() if param.requires_grad]
    optimizer = torch.optim.Adam(trainable, lr = lr)

    for epoch in range(epochs):
      for inputs, labels in tqdm(trainloader, desc = 'Training Column {}, Epoch {}/{}'.format(col, epoch + 1, epochs)):
        inputs = inputs.to(device)
        labels = labels.to(device)
        labels = labels % num_outputs

        output = self(col, inputs)
        loss = F.cross_entropy(output, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return self.evalTask(col, testloader, num_outputs)

  def evalTask(self, task_idx, testloader, num_outputs = 10):
    """ Print and return the accuracy of column `task_idx` on `testloader`.

    Puts the whole network in eval mode. Labels are reduced modulo
    `num_outputs` before being compared with the predictions. An empty
    loader yields an accuracy of 0.0 instead of a division by zero.
    """
    self.eval()
    total = 0
    correct = 0
    with torch.no_grad():
      for inputs, labels in tqdm(testloader, desc = 'Evaluating Task {}'.format(task_idx)):
        inputs = inputs.to(device)
        labels = labels.to(device)
        labels = labels % num_outputs
        output = self(task_idx, inputs)

        _, predictions = torch.max(output, dim = 1)
        total += labels.shape[0]
        correct += int((predictions == labels).sum())

    accuracy = correct / total if total else 0.0
    print("Col: {}, Testing Accuracy: {:.4f}".format(task_idx, accuracy))
    return accuracy

# CIFAR-100 Example

In [None]:
# Split CIFAR-100: per-task loaders plus the hyper-parameters of the run
trainloaders, testloaders = get_cifar100()
num_outputs = 20  # output size of each column; labels are taken modulo this value
lr = 1e-3
epochs = 10

In [None]:
# A new ResNet-18 column is added and trained for every task
model = Pnn(colGen = Resnet18Generator(num_classes = num_outputs))
model.to(device)

for trainloader, testloader in zip(trainloaders, testloaders):
  model.trainTask(lr, epochs, trainloader, testloader, num_outputs = num_outputs)
  # Freeze everything trained so far before the next task adds its column
  model.freezeAllColumns()
  

In [None]:
# Re-evaluate every task with its own column once all tasks have been trained
for task_idx, testloader in enumerate(testloaders):
  model.evalTask(task_idx, testloader, num_outputs = num_outputs)

# Permuted MNIST Example

In [None]:
# Permuted MNIST: per-task loaders plus the hyper-parameters of the run
trainloaders, testloaders = get_permuted_mnist()
num_outputs = 10  # output size of each column
lr = 1e-3
epochs = 5

In [None]:
# A new ResNet-18 column is added and trained for every permutation
model = Pnn(colGen =  Resnet18Generator(num_classes = num_outputs))
model.to(device)

for trainloader, testloader in zip(trainloaders, testloaders):
  model.trainTask(lr, epochs, trainloader, testloader, num_outputs = num_outputs)
  # Freeze everything trained so far before the next task adds its column
  model.freezeAllColumns()

In [None]:
# Re-evaluate every task with its own column once all tasks have been trained
for task_idx, testloader in enumerate(testloaders):
  model.evalTask(task_idx, testloader, num_outputs = num_outputs)

# Flowers-102 Example

In [None]:
# Split Flowers-102: per-task loaders plus the hyper-parameters of the run
trainloaders, testloaders = get_flowers102()
num_outputs = 17  # output size of each column; labels are taken modulo this value
lr = 1e-3
epochs = 20

In [None]:
# A new ResNet-18 column is added and trained for every task
model = Pnn(colGen = Resnet18Generator(num_classes = num_outputs))
model.to(device)

for trainloader, testloader in zip(trainloaders, testloaders):
  model.trainTask(lr, epochs, trainloader, testloader, num_outputs = num_outputs)
  # Freeze everything trained so far before the next task adds its column
  model.freezeAllColumns()

In [None]:
# Re-evaluate every task with its own column once all tasks have been trained
for task_idx, testloader in enumerate(testloaders):
  model.evalTask(task_idx, testloader, num_outputs = num_outputs)

# Different Datasets Example

In [None]:
# One task per dataset, in the order returned by get_diff_datasets
trainloaders, testloaders = get_diff_datasets()
num_outputs = [10, 100, 102]  # output size of the column of each task, same order as the loaders
lr = 1e-3
epochs = 10

In [None]:
# One column per dataset; each column gets the output size of its own task
model = Pnn(colGen = Resnet18Generator(num_classes = num_outputs[0]))
model.to(device)

for task_idx, (trainloader, testloader) in enumerate(zip(trainloaders, testloaders)):
  # trainTask builds the new column through the generator, so the output
  # size of this task has to be set before the call.
  model.colGen.num_classes = num_outputs[task_idx]
  # `epochs` is a single int shared by all tasks (indexing it raised a TypeError)
  model.trainTask(lr, epochs, trainloader, testloader, num_outputs = num_outputs[task_idx])
  model.freezeAllColumns()

In [None]:
# Re-evaluate every dataset with its own column and its own output size
for task_idx, testloader in enumerate(testloaders):
  model.evalTask(task_idx, testloader, num_outputs = num_outputs[task_idx])