In [11]:
from lib.abstract_torch import get_device
from lib.dataloader import get_task_loaders
from lib.models import ANN
from lib.growth_schedules import get_CL_schedule
from lib.train import train_ANN

import os
import numpy

In [8]:
# Experiment switch. NOTE(review): not referenced in any visible cell — confirm whether it is still needed.
train_vanilla: bool = True

In [2]:
# Pick the compute device via the project helper; it is used for the model (`.to(device)`) and for training/evaluation calls.
device = get_device()

In [6]:
# Notebooks have no `__file__`, so resolve paths from the current working
# directory. (The previous `os.path.dirname(os.path.abspath("__file__"))` trick
# resolved a literal file named "__file__" against the cwd — the same directory,
# just obscured.)
path = os.getcwd()
# os.path.join is portable and avoids a doubled separator when cwd is the root.
data_path = os.path.join(path, "data")

In [7]:
# Data: mini-batch size used by the loaders and by every training/evaluation call.
batch_size = 128

# Network architecture: input features (28x28 images), number of classes,
# and the hidden width of the root model vs. the target width.
num_inputs, num_outputs = 28 * 28, 10
num_hidden_root, num_hidden_target = 10, 100  # NOTE(review): num_hidden_target is not used in the visible cells

# Training setup: loss / optimizer identifiers understood by lib.train.
loss_name, optimizer_name = "CE", "Adam"

# Hyperparameters used when training the root model with train_ANN.
num_epochs, lr = 3, 5e-3

## Reproducibility

In [None]:
# Seed passed to train_ANN for the root-model run.
random_seed = 88
# One seed per task; each is handed to get_task_loaders to build that task's data
# (presumably the pixel permutation of p-MNIST — confirm in lib.dataloader).
permutation_random_seeds = [seed for seed in range(10)]

## Build the permuted-MNIST (p-MNIST) task loaders

In [None]:
# Build one (train, val, test) loader triple per task, one task per permutation seed.
# The loop variable is deliberately NOT named `random_seed`: that name holds the
# training seed from the Reproducibility cell, and reusing it here overwrote it
# with the last permutation seed before it was passed to train_ANN.
train_loaders_list, val_loaders_list, test_loaders_list = [], [], []
for permutation_seed in permutation_random_seeds:
    train_loader, val_loader, test_loader = get_task_loaders(
        data_path, batch_size, permutation_seed, train_percentage=0.8, download=False
    )
    train_loaders_list.append(train_loader)
    val_loaders_list.append(val_loader)
    test_loaders_list.append(test_loader)

## Train a root model, then grow it across tasks

In [None]:
# Growth parameters:
#   num_neurons -> forwarded to get_CL_schedule to build the growth schedule
#   lr_root     -> NOTE(review): not referenced in the visible cells; the root model is trained with `lr`
num_neurons, lr_root = 9, 2e-3

In [None]:
# Build the growth schedule from num_neurons; it is indexed once per growth step
# in the training cell below. Entry contents are defined by lib.growth_schedules — TODO confirm.
growth_schedule = get_CL_schedule(num_neurons)

In [None]:
# Define the root model: a small ANN with `num_hidden_root` hidden units, which
# the growth steps below expand.
root_model = ANN(num_inputs, num_hidden_root, num_outputs).to(device)

# Train the root model on task 0 only.
# NOTE(review): the growth-parameters cell defines `lr_root`, but `lr` is what is
# passed here — confirm which learning rate the root run should use.
train_ANN(root_model, loss_name, optimizer_name, lr, train_loaders_list[0], num_epochs, batch_size, device, random_seed)

# Experiment constants (values unchanged from the original cell).
num_root_eval_tasks = 5      # the root model is evaluated on tasks 0..4
growth_tasks = range(5, 6)   # task indices used for growth steps; range(5, num_tasks) covers all remaining tasks
num_growth_epochs = 2
lr_growth = 5 * 1e-4         # loop-invariant, so defined once here
num_tasks = len(test_loaders_list)

# TODO(review): `train`, `test` and `init_name` are not defined or imported in the
# visible notebook (only `train_ANN` is imported from lib.train); bring them into
# scope before running this cell.


def record_test_accuracies(accs_matrix, model, row, num_tasks_seen):
    """Store `model`'s test accuracy (rounded to 2 decimals) on tasks 0..num_tasks_seen-1 in `accs_matrix[row]`."""
    for task_idx in range(num_tasks_seen):
        accs_matrix[row, task_idx] = round(test(model, test_loaders_list[task_idx], batch_size, device), 2)


# test_accs_matrix[row, task] = rounded test accuracy on `task`.
# Row 0 is the root model; row `task - num_root_eval_tasks + 1` is the model after
# growing on `task`, so there is one row per task that can be grown on
# (6 x 10 with the current 10 tasks, as before).
test_accs_matrix = numpy.zeros((1 + num_tasks - num_root_eval_tasks, num_tasks))
record_test_accuracies(test_accs_matrix, root_model, row=0, num_tasks_seen=num_root_eval_tasks)

for task in growth_tasks:
    print("iteration :", task)
    train(root_model,
          num_outputs,
          growth_schedule[task - num_root_eval_tasks],
          loss_name,
          optimizer_name, lr_growth,
          train_loaders_list[task], val_loaders_list[task],
          num_growth_epochs, batch_size,
          device,
          init_name=init_name,
          verbose=0)
    # After growing on `task`, evaluate on every task seen so far (0..task).
    record_test_accuracies(test_accs_matrix, root_model,
                           row=task - num_root_eval_tasks + 1,
                           num_tasks_seen=task + 1)

