In [1]:
from lib.abstract_torch import get_device, get_loss, get_optimizer
from lib.dataloader import get_MNIST_loaders, get_FMNIST_loaders, get_CIFAR100_loaders, get_CIFAR10_loaders
from lib.model import ANN
from lib.train import train
from lib.test import test
from lib.growth_schedules import get_CL_schedule
from lib.visualize import (visualize_pathes,
                          visualize_statistical_reliability,
                          visualize_box_plot)               

import os
import numpy as np

In [2]:
device = get_device()

In [3]:
path = os.path.dirname(os.path.abspath("__file__"))
data_path = path + "\\data"

# MNIST

## Get data loaders

In [4]:
batch_size=128

In [5]:
# Get data for task 2
train_loader_0_to_4, val_loader_0_to_4, test_loader_0_to_4 = get_MNIST_loaders(data_path, range(5), batch_size)
train_loader_0_to_9, val_loader_0_to_9, test_loader_0_to_9 = get_MNIST_loaders(data_path, range(10), batch_size)

train_loaders, val_loaders, test_loaders = [], [], []

for i in range(10) :
    class_name = [i]
    train_loader, val_loader, test_loader = get_MNIST_loaders(data_path, [i], batch_size)
    train_loaders.append(train_loader)
    val_loaders.append(val_loader)
    test_loaders.append(test_loader)

## Random Initialization

In [6]:
init_name = "random"
savefig = "MNIST_random_CL"

### Define, train and test both root & target models

In [7]:
# Network's initial architecture
num_inputs = 28*28
num_hidden = 50
num_outputs = 10

# Network's final architecture
num_hidden_target = 100

# Loss & optimizer
loss_name = "MSE" # "CE"
optimizer_name = "Adam"

# Hyperparameters
num_epochs = 3
lr = 5e-3
growth_schedule = None

# Experiment parameters
num_repetitions = 1

#### Root model

In [8]:
test_acc_roots = []
for i in range(num_repetitions) :
    root_model = ANN(num_inputs, num_hidden, num_outputs).to(device)
    _ = train(root_model, num_outputs, growth_schedule, loss_name, optimizer_name, lr, train_loader_0_to_4, val_loader_0_to_4, num_epochs, batch_size, device, verbose=0)
    test_acc_root = test(root_model, test_loader_0_to_4, batch_size, device)
    test_acc_roots.append(test_acc_root)

100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:05<00:00,  1.84s/it]


In [9]:
test_acc_roots

[97.85300000000001]

#### Target model

In [10]:
test_acc_targets = []
for i in range(num_repetitions) :
    target_model = ANN(num_inputs, num_hidden_target, num_outputs).to(device)
    _ = train(target_model, num_outputs, growth_schedule, loss_name, optimizer_name, lr, train_loader_0_to_9, val_loader_0_to_9, num_epochs, batch_size, device, verbose=0)
    test_acc_target = test(target_model, test_loader_0_to_9, batch_size, device)
    test_acc_targets.append(test_acc_target)

100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:10<00:00,  3.48s/it]


In [11]:
test_acc_targets

[93.73974358974357]

### Grow root model

In [25]:
# Growth parameters
num_neurons = 10
lr_root = 5e-3
lr_growth = 1e-3

In [26]:
growth_schedules = get_CL_schedule(num_neurons)

In [30]:
test_accs_matrices_repeted = []
for i in range(num_repetitions) :
    test_accs_matrices = []
    for growth_schedule in growth_schedules :
        # Define & pretrain the root model
        root_model = ANN(num_inputs, num_hidden, num_outputs).to(device)
        _ = train(root_model, num_outputs, None, loss_name, optimizer_name, lr_root, train_loader_0_to_4, val_loader_0_to_4, num_epochs, batch_size, device, verbose=0)
        # Initialize the matrix containing test accuracies on various task
        test_accs_matrix = np.zeros((6,10))
        for j in range(5) :
            test_accs_matrix[0,j] = round(test(root_model, test_loaders[j], batch_size, device),2)
        for i, train_loader, val_loader in zip(range(5,10), train_loaders[5:], val_loaders[5:]):
            _ = train(root_model,
                      num_outputs,
                      growth_schedule,
                      loss_name, 
                      optimizer_name, lr_growth, 
                      train_loaders[i], val_loaders[i], 
                      2, batch_size, 
                      device,
                      init_name=init_name,
                      verbose=0)
            for j in range(i+1) :
                test_accs_matrix[i-4,j] = round(test(root_model, test_loaders[j], batch_size, device),2)
        test_accs_matrices.append(test_accs_matrix)
    test_accs_matrices_repeted.append(test_accs_matrices)

100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:05<00:00,  1.91s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.85it/s]
100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.59it/s]
100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.52it/s]
100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.63it/s]
100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.61it/s]


In [31]:
print(test_accs_matrices_repeted[0][0])

[[99.55 99.02 96.39 95.54 98.89  0.    0.    0.    0.    0.  ]
 [97.1  98.73 90.04 37.83 79.13 76.69  0.    0.    0.    0.  ]
 [51.    3.12 23.73  4.91 13.95  0.   99.55  0.    0.    0.  ]
 [ 0.78  1.56  0.1   6.92  0.11  0.    0.33 92.19  0.    0.  ]
 [ 0.    0.    1.17  7.03  3.35  0.   26.34 40.43 15.74  0.  ]
 [ 0.    0.    3.22  1.    0.78  0.    7.03  2.24  0.11 93.53]]


In [29]:
root_model.named_parameters

<bound method Module.named_parameters of ANN(
  (fc1): Linear(in_features=784, out_features=50, bias=True)
  (fc2): Linear(in_features=50, out_features=50, bias=True)
  (fc3): Linear(in_features=50, out_features=10, bias=True)
)>

### Visualize results

In [16]:
visualize_pathes(num_hidden, num_hidden_target, growth_schedules, test_accs, savefig)

NameError: name 'test_accs' is not defined

In [None]:
visualize_statistical_reliability (test_accs_repeted, test_accs, test_acc_roots, test_acc_targets,
                                   free_lim=True, savefig=savefig)

In [None]:
visualize_box_plot(test_accs, test_acc_root, test_acc_target, savefig)