<a href="https://colab.research.google.com/github/HyperGlitch24/AMAI/blob/main/PNN_tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Progressive Neural Networks (PNNs)

This notebook explores Progressive Neural Networks (PNNs) for image classification. Unlike a conventional neural network, a PNN assigns a separate column to each new task, where each column is a neural network of its own. To reuse the information learned by earlier columns, a PNN uses adapters to connect the columns laterally. We will see in practice how a multi-column network can be trained sequentially on a set of tasks without catastrophic forgetting.
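
For reference, the lateral connections can be written down compactly. The form below follows the basic formulation in the paper that introduced PNNs (Rusu et al., 2016); the symbols are not defined elsewhere in this notebook, so treat them as notation introduced here: $h_i^{(k)}$ is the activation of layer $i$ in column $k$, $W_i^{(k)}$ are that column's own weights, $U_i^{(k:j)}$ are the lateral weights from column $j$ into column $k$, and $f$ is the nonlinearity.

$$
h_i^{(k)} = f\left( W_i^{(k)} h_{i-1}^{(k)} + \sum_{j<k} U_i^{(k:j)} h_{i-1}^{(j)} \right)
$$

When column $k$ is trained, only $W^{(k)}$ and $U^{(k:j)}$ are updated. The earlier columns $j<k$ stay frozen, so their behaviour on earlier tasks cannot change. The adapters used in this notebook can be read as one particular choice for the lateral term.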

## TO-DOs:
- Construct two sets of tasks: MNIST and Permuted MNIST
- Design the PNN
- Define three baselines
    - Trained only on the target task (Permuted MNIST in this case) - Baseline 1
    - Pretrained on the source task, fine-tuned on the target task
        - All layers are frozen except the final output layer - Baseline 2
        - All layers are trained - Baseline 3

    These cases are quite similar to what we have seen in the transfer learning tutorial.

- Train and evaluate on the two contexts

In [None]:
# Imports
import copy
import random

import matplotlib.pyplot as plt
import numpy as np  # used by the visualization helpers below
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.models as models
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from tqdm import tqdm

In [None]:
# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Define your PNN

- First define a single PNN column. It can be any neural network of your choice.
- Then define the PNN architecture so that more than one column can be added: freeze the existing column, add lateral connections from it, and train only the newly added column.

In [None]:
# PNN column
class PNN_column(nn.Module):
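    def __init__(self, input_channels=1, output_classes=10):
        super(PNN_column, self).__init__()
        # TO-DO: define the layers of the column

    def forward(self, x):
        # TO-DO: compute `out` from `x` using the layers defined above
        return out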

In [None]:
# Progressive Neural Network definition
class PNN(nn.Module):
    def __init__(self, prev_model=None):
        super(PNN, self).__init__()

        """ TO-DO
        1. Initialize the column_task with the appropriate input channels and output classes.
        2. If prev_model is provided, copy its parameters to the column_task.
        3. Define the adapters for the previous model if it exists.
        """
        self.column_task = PNN_column()

    def forward(self, x):
        """ TO-DO
        1. If a previous column exists, pass x through it and keep its intermediate activations.
        2. Pass x through column_task, adding the adapted activations of the previous column at each lateral connection.
        3. Return the output of column_task.
        """
        return out
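
One possible way to fill in the two skeletons above is sketched below. It is not the reference solution of this notebook: the class names `Column` and `TwoColumnPNN`, the hidden size of 64, the choice of a small MLP, and the single lateral connection (hidden layer of the old column into the output layer of the new column) are all assumptions made for illustration.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


class Column(nn.Module):
    """Hypothetical column: a small MLP for 28x28 single-channel images."""

    def __init__(self, input_size=28 * 28, hidden=64, output_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden)
        self.fc2 = nn.Linear(hidden, output_classes)

    def forward(self, x, lateral=None):
        h = F.relu(self.fc1(x.flatten(1)))  # hidden activation, reused by later columns
        out = self.fc2(h)
        if lateral is not None:
            out = out + lateral  # additive lateral input from an earlier column
        return out, h


class TwoColumnPNN(nn.Module):
    """Hypothetical PNN with one trainable column and at most one frozen previous column."""

    def __init__(self, prev_column=None, hidden=64, output_classes=10):
        super().__init__()
        self.column = Column(hidden=hidden, output_classes=output_classes)
        self.prev_column = None
        self.adapter = None
        if prev_column is not None:
            # Keep a private, frozen copy of the previously trained column
            self.prev_column = copy.deepcopy(prev_column)
            for param in self.prev_column.parameters():
                param.requires_grad = False
            # Adapter: maps the old column's hidden activation to the new output layer
            self.adapter = nn.Linear(hidden, output_classes)

    def forward(self, x):
        lateral = None
        if self.prev_column is not None:
            with torch.no_grad():  # the old column only provides features
                _, h_prev = self.prev_column(x)
            lateral = self.adapter(h_prev)
        out, _ = self.column(x, lateral)
        return out


# Usage on random data: the first column stands alone, the second one reuses it
torch.manual_seed(0)
x = torch.randn(4, 1, 28, 28)
first = TwoColumnPNN()
second = TwoColumnPNN(prev_column=first.column)
logits = second(x)
logits.sum().backward()  # gradients reach the new column and the adapter only
print(logits.shape)
```

Because the frozen column is evaluated under `torch.no_grad()` and its parameters have `requires_grad = False`, training the second column cannot alter what the first column computes for Task 1.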

### Dataset Definition
- Create two contexts: MNIST and Permuted MNIST


In [1]:
# MNIST Dataloaders
transform = transforms.ToTensor()
mnist_trainset = datasets.MNIST(root='/app/src/Mnist', train=True, download=True, transform=transform)
mnist_testset = datasets.MNIST(root='/app/src/Mnist', train=False, download=True, transform=transform)

print(len(mnist_trainset), len(mnist_testset))  # 60000, 10000
config = {'size': 28, 'channels': 1, 'classes': 10}

#@title Visualization functions
def multi_context_barplot(axis, accs, title=None):
    '''Generate barplot using the values in [accs].'''
    contexts = len(accs)
    axis.bar(range(contexts), accs, color='k')
    axis.set_ylabel('Testing Accuracy (%)')
    axis.set_xticks(range(contexts), [f'Context {i+1}' for i in range(contexts)])
    if title is not None:
        axis.set_title(title)

def plot_examples(axis, dataset, context_id=None):
    '''Plot 25 examples from [dataset].'''
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=25, shuffle=True)
    image_tensor, _ = next(iter(data_loader))
    image_grid = make_grid(image_tensor, nrow=5, pad_value=1) # pad_value=0 would give black borders
    axis.imshow(np.transpose(image_grid.numpy(), (1,2,0)))
    if context_id is not None:
        axis.set_title("Context {}".format(context_id+1))
    axis.axis('off')

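# Function to apply a given permutation to the pixels of an image.
def permutate_image_pixels(image, permutation):
    '''Permutate the pixels of [image] according to [permutation].'''
    if permutation is None:
        return image  # no permutation given: leave the image unchanged
    # One possible completion: flatten each channel, reorder its pixels, restore the shape
    c, h, w = image.size()
    image = image.view(c, -1)[:, permutation]  # same permutation for every channel
    return image.view(c, h, w)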

# Class to create a dataset with images that have all been transformed in the same way.
class TransformedDataset(torch.utils.data.Dataset):
    '''To modify an existing dataset with a transform.
    Useful for creating different permutations of MNIST without loading the data multiple times.'''

    def __init__(self, original_dataset, transform=None, target_transform=None):
        super().__init__()
        self.dataset = original_dataset
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        (input, target) = self.dataset[index]
        if self.transform:
            input = self.transform(input)
        if self.target_transform:
            target = self.target_transform(target)
        return (input, target)

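import numpy as np
contexts = 2
# One possible choice (an assumption): context 1 keeps the original pixel order (None),
# every further context gets its own fixed random permutation of the 28*28 pixels.
permutations = [None] + [np.random.permutation(config['size']**2) for _ in range(contexts-1)]
# Specify for each context the transformed train- and testset
train_datasets = []
test_datasets = []
for context_id, perm in enumerate(permutations):
    # perm=perm binds the current permutation to this context's transform
    pixel_transform = transforms.Lambda(lambda x, perm=perm: permutate_image_pixels(x, perm))
    train_datasets.append(TransformedDataset(mnist_trainset, transform=pixel_transform))
    test_datasets.append(TransformedDataset(mnist_testset, transform=pixel_transform))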

# Visualize the contexts
figure, axis = plt.subplots(1, contexts, figsize=(3*contexts, 4))

for context_id in range(len(train_datasets)):
    plot_examples(axis[context_id], train_datasets[context_id], context_id=context_id)



### Define the training and the evaluation function

In [None]:
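# Training function
def train(model, loader, optimizer, criterion, epochs=1):
    model.to(device)
    for epoch in range(epochs):
        model.train()
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()
        # Reports the loss of the last mini-batch of the epoch
        print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")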

In [None]:
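# Evaluation function
def evaluate(model, loader):
    model.to(device)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            pred = model(data).argmax(dim=1)  # predicted class per sample
            correct += (pred == target).sum().item()
            total += target.size(0)
    return correct / total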

### Now train the PNN: first on the first context, using the first column


In [None]:
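# Train Task 1 (original MNIST)
# Loaders for context 1 (the batch size of 64 is an assumption, adjust as needed)
train_loader = DataLoader(train_datasets[0], batch_size=64, shuffle=True)
test_loader = DataLoader(test_datasets[0], batch_size=64, shuffle=False)

model1 = PNN()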
optimizer1 = torch.optim.Adam(model1.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

print("Training Task 1 (MNIST)")
train(model1, train_loader, optimizer1, criterion, epochs=3)
acc1 = evaluate(model1, test_loader)
print("Task 1 Test Accuracy:", acc1)

Training Task 1 (MNIST)
Epoch 1, Loss: 0.0116
Epoch 2, Loss: 0.0078
Epoch 3, Loss: 0.0078
Task 1 Test Accuracy: 0.988


### Then train on the second context using the second column

In [None]:
# Copy the trained Task 1 column; it serves as the frozen previous column for Task 2
frozen_model = copy.deepcopy(model1.column_task)
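
Note that `copy.deepcopy` only duplicates the weights; it does not freeze them. A minimal sketch of the freezing step follows, together with a check that an optimizer step leaves the frozen weights untouched. The names `trained_column` and `new_head` are hypothetical stand-ins, not objects defined in this notebook.

```python
import copy

import torch
import torch.nn as nn

# Hypothetical stand-in for the trained Task 1 column
trained_column = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

frozen_column = copy.deepcopy(trained_column)
for param in frozen_column.parameters():
    param.requires_grad = False  # exclude from gradient computation
frozen_column.eval()  # also fixes dropout / batch-norm behaviour, if present

# Hypothetical new trainable part stacked on the frozen features
new_head = nn.Linear(10, 10)
all_params = list(frozen_column.parameters()) + list(new_head.parameters())
trainable = [p for p in all_params if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=0.001)

# One optimizer step on random data
before = [p.clone() for p in frozen_column.parameters()]
x = torch.randn(8, 1, 28, 28)
loss = new_head(frozen_column(x)).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()

unchanged = all(torch.equal(b, p) for b, p in zip(before, frozen_column.parameters()))
print("frozen column unchanged:", unchanged)
```

Passing only the trainable parameters to the optimizer is optional here, since parameters without gradients are skipped anyway, but it makes the intent explicit.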

In [None]:
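# Train Task 2 (Permuted MNIST)
# Loaders for context 2 (the batch size of 64 is an assumption, adjust as needed)
train_loader2 = DataLoader(train_datasets[1], batch_size=64, shuffle=True)
test_loader2 = DataLoader(test_datasets[1], batch_size=64, shuffle=False)

model2 = PNN(prev_model=frozen_model)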
optimizer2 = torch.optim.Adam(model2.parameters(), lr=0.001)

print("\nTraining Task 2 (Permuted MNIST)")
train(model2, train_loader2, optimizer2, criterion, epochs=3)
acc2 = evaluate(model2, test_loader2)
print("Task 2 Test Accuracy:", acc2)


Training Task 2 (Permuted MNIST)
Epoch 1, Loss: 0.0809
Epoch 2, Loss: 0.4536
Epoch 3, Loss: 0.2599
Task 2 Test Accuracy: 0.9626


## Train and Evaluate the baselines
- Now define all the baselines mentioned above
- Individually train on the second context
- Compare the results
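
The three baselines differ only in how the model is initialized and which parameters are allowed to train. The sketch below sets this up for a hypothetical classifier (`make_model` and its layer sizes are assumptions, not the architecture used in this notebook) and counts the trainable parameters in each case. Training itself is left out.

```python
import copy

import torch.nn as nn


def make_model():
    """Hypothetical classifier; the last module is the output layer."""
    return nn.Sequential(
        nn.Flatten(),
        nn.Linear(28 * 28, 64),
        nn.ReLU(),
        nn.Linear(64, 10),
    )


def count_trainable(model):
    """Number of parameters that would be updated by an optimizer."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Stand-in for a model that has already been trained on the source task
source_model = make_model()

# Baseline 1: fresh model, nothing reused from the source task
baseline1 = make_model()

# Baseline 2: start from the source weights, train only the output layer
baseline2 = copy.deepcopy(source_model)
for param in baseline2.parameters():
    param.requires_grad = False
for param in baseline2[-1].parameters():
    param.requires_grad = True

# Baseline 3: start from the source weights, fine-tune every layer
baseline3 = copy.deepcopy(source_model)

for name, model in [("Baseline 1", baseline1), ("Baseline 2", baseline2), ("Baseline 3", baseline3)]:
    print(name, "trainable parameters:", count_trainable(model))
```

Each baseline can then be passed to the same `train` and `evaluate` functions used for the PNN, so that the comparison differs only in the model.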

In [None]:
""" TO-DO  """

## Baseline 1

"""
 Baseline 1: Train a new model for target task without reusing previous knowledge
"""


## Baseline 2

""" Baseline 2: Initialize a new model with the parameters of the model pretrained on the source task, freeze all layers except the final output layer, and train on the target task.
"""


## Baseline 3
""" Baseline 3: Initialize a new model with the parameters of the model pretrained on the source task and fine-tune all layers on the target task.
"""


### Plot the performance of PNNs and the baselines
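
A possible plotting helper is sketched below, in the style of `multi_context_barplot` from earlier in this notebook. The function name `plot_accuracies` and the `results` dictionary are assumptions. The only value filled in is the PNN accuracy on Task 2 printed above; the baseline entries have to be added once those models have been trained.

```python
import matplotlib.pyplot as plt


def plot_accuracies(results, title=None):
    """Bar plot of test accuracy in percent.

    `results` maps a method name to its accuracy as a fraction in [0, 1].
    """
    names = list(results)
    values = [100 * results[name] for name in names]
    figure, axis = plt.subplots(figsize=(1.5 * len(names) + 2, 4))
    axis.bar(range(len(names)), values, color='k')
    axis.set_xticks(range(len(names)))
    axis.set_xticklabels(names)
    axis.set_ylabel('Testing Accuracy (%)')
    axis.set_ylim(0, 100)
    if title is not None:
        axis.set_title(title)
    return figure, axis


# PNN accuracy on Task 2 as printed earlier in this notebook;
# add "Baseline 1", "Baseline 2" and "Baseline 3" once they are trained
results = {"PNN": 0.9626}
figure, axis = plot_accuracies(results, title="Context 2 (Permuted MNIST)")
```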

In [None]:
""" TO-DO  """
