In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, transforms

import os
from tqdm import tqdm
import numpy as np

from scipy.sparse.linalg import svds

In [2]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Data parameter
batch_size=128

# Network's initial architecture
num_inputs = 28*28
num_hidden = 10
num_outputs = 10

# Network's final architecture
num_hidden_target = 10

In [4]:
path = os.path.dirname(os.path.abspath("__file__"))

data_path = path + "\\data"

## Load MNIST data

In [5]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

def get_MNIST_loaders(path, class_names, batch_size) :

    # load MNIST 
    mnist_train = datasets.MNIST(root=path, train=True, download=False, transform=transform)
    mnist_test = datasets.MNIST(root=path, train=False, download=False, transform=transform)


    # create a mask to filter indices for each label
    train_mask = torch.tensor([label in class_names for label in mnist_train.targets])
    test_mask = torch.tensor([label in class_names for label in mnist_test.targets])

    # Create Subset datasets for train, validation, and test
    train_dataset = Subset(mnist_train, torch.where(train_mask)[0])
    test_dataset = Subset(mnist_test, torch.where(test_mask)[0])

    # split train into train & validation
    train_size = int(0.7 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    return train_loader, val_loader, test_loader

In [6]:
train_loader, val_loader, test_loader = get_MNIST_loaders(data_path, [i for i in range(10)], batch_size)

## Define model

In [7]:
class ANN (nn.Module):
    def __init__(self, num_inputs, num_hidden, num_outputs):
        super().__init__()
        
        self.fc1 = nn.Linear(num_inputs,num_hidden)
        self.fc2 = nn.Linear(num_hidden, num_hidden)
        self.fc3 = nn.Linear(num_hidden, num_outputs)
        
        self.activation = torch.sigmoid ## nn.ReLU()
        
    def forward(self, x) :
        x = self.fc1(x)
        x = self.activation(x)
        x = self.fc2(x)
        x = self.activation(x)
        x = self.fc3(x)
        x = self.activation(x)
        
        return x
    
    def add_neurons (self, fc1_weight_grad, fc1_bias_grad, fc2_weight_grad, num_neurons, device) :
        # Add weights and biases to the layer of interest (fc1)
        num_in_1, num_out_1 = self.fc2.in_features, self.fc2.out_features
        num_in_2, num_out_2 = self.fc3.in_features, self.fc3.out_features

        # PRE-LAYER
        # Set parameters
        self.fc2.weight = nn.Parameter(torch.cat((self.fc2.weight,torch.zeros(num_neurons, num_in_1).to(device))))
        self.fc2.bias = nn.Parameter(torch.cat((self.fc2.bias, torch.zeros(num_neurons).to(device))))

        # Set gradients
        self.fc2.weight.grad = nn.Parameter(torch.cat((fc1_weight_grad,torch.zeros(num_neurons, num_in_1).to(device)), dim=0))
        self.fc2.bias.grad = nn.Parameter(torch.cat((fc1_bias_grad, torch.zeros(num_neurons).to(device)), dim=0))

        # POST-LAYER
        # Set parameters
        self.fc3.weight = nn.Parameter(torch.cat((self.fc3.weight, torch.zeros(num_out_2,num_neurons).to(device)), dim=1))

        # Set gradients
        self.fc3.weight.grad = torch.cat((fc2_weight_grad, torch.zeros(num_out_2, num_neurons).to(device)), dim=1)
        print((self.fc2.weight.grad != 0).sum())

## Training pipeline

In [8]:
def count_all_parameters(model) :
    num_params = 0
    for param_name, param in model.named_parameters():
        num_param = torch.numel(param)
        print(param_name, ":", num_param)
        num_params += num_param
    return num_params

In [9]:
def get_batch_accuracy(model, data, targets, batch_size):
    output = model(data.view(batch_size, -1))
    idx = output.argmax(1)
    acc = np.mean((targets == idx).detach().cpu().numpy())

    return round(acc*100,2)

In [10]:
def compute_loss(model, data, targets, loss, loss_name, batch_size) :
    if loss_name == "CE":
        y = model(data.view(batch_size, -1))
        loss_val = loss(y, targets)
    else :
        y = model(data.view(batch_size, -1))
        one_hot_targets = nn.functional.one_hot(targets, num_classes=10).to(y.dtype)
        loss_val = loss(y, one_hot_targets)
    return y, loss_val

In [11]:
def compute_val (model, loss, loss_name, val_loader, val_loss_hist, val_acc_hist, epoch, batch_size, device, print_shit=False) :
    model.eval()
    val_data, val_targets = next(iter(val_loader))
    val_data = val_data.to(device)
    val_targets = val_targets.to(device)
    
    # Forward path
    y, val_loss_val = compute_loss(model, val_data, val_targets, loss, loss_name, batch_size)
    val_loss_hist.append(val_loss_val.item())
    
    # ACCURACY
    if print_shit :
        print(f"Epoch {epoch}")
        # print(f"Train Set Loss: {train_loss_hist[epoch]:.2f}")
        print(f"Val Set Loss: {val_loss_hist[epoch]:.2f}")
        print("\n")
    val_acc_hist.append(get_batch_accuracy(model, val_data, val_targets, batch_size))
    
    
    return val_loss_hist, val_acc_hist

GradMax offers a way to initialize the added neurons :
- $W_l^{new}$ set to $0$ (cf hypothesis between eq (8) and (9)) ♠
- $\frac{\partial L}{\partial W_{l}^{new}}$ set to $W_{l+1}^{new,T} \mathbb{E}_D\left[\frac{\partial L}{\partial z_{l+1}}h_{l-1}^T\right]$ even though (9) suggests $W_{l+1}^{new,T} \frac{\partial L}{\partial z_{l+1}}h_{l-1}^T$ ♠
- $B_l^{new}$ set to 0 (not mentioned in the paper, "zeros or ones" according to layers.py file in the code) ♠
- $\frac{\partial L}{\partial B_{l}^{new}}$ set to 0 (not mentioned in the paper) ♠
- $W_{l+1}^{new}$ set as the top $k$ left-singular vectors of the matrix $\mathbb{E}_D\left[\frac{\partial L}{\partial z_{l+1}}h_{l-1}^T\right]$ and scaling them by $\frac{c}{||(\sigma_1,...,\sigma_k)||}$ (where $\sigma_i$
is the $i$-th largest singular value) ♠
- $\frac{\partial L}{\partial W_{l+1}^{new}}$ set to $0$ (eq (10)) ♠

On veut calculer $\mathbb{E}_D\left[\frac{\partial L}{\partial z_{l+1}}h_{l-1}^T\right]$.
Détail :
- $\frac{\partial L}{\partial z_{l+1}}$ : $[10,1]$ ou $[num_{hidden},1]$ (c'est le gradient calculé dans le layer suivant)
- $h_{l-1}^T$ : $[1,784]$ (c'est simplement l'input x)
Donc en fait, on veut juste la 

On sait calculer obtenir ces deux quantités, il suffit de faire la moyenne sur les données vue depuis le dernier ajout d'un neurone (tout le training set si on ajoute des neurones à la fin de chaque epoch, ou tout le batch si on ajoute des neurones à la fin de batchs)

Impémentation :
- On a besoin de savoir pendant quel layer on veut grow dans la training loop, pour savoir quel gradient on doit stocker.

In [12]:
def add_neurons (fc1, fc2, num_neurons, grow_matrix, c, device) :
    # Add weights and biases to the layer of interest (fc1)
    num_in_1, num_out_1 = fc1.in_features, fc1.out_features
    new_fc1 = nn.Linear(in_features=num_in_1, out_features=num_out_1 + num_neurons).to(device)
    
    # Solve optimization problem (11)
    u, s, vh = svds(grow_matrix.cpu().detach().numpy(), k=num_neurons, return_singular_vectors=True)
    #print(fc1, (grow_matrix == 0).sum())
    #print("u :", u.shape)
    #print("vh :", vh.shape)
    eigenvals, eigenvecs = (s**2), u[::-1]
    #print(s)
    scaler = c / np.sqrt(eigenvals.sum())
    scaler = 1
    #print(scaler)
    new_weights_next_layer = (1/scaler)*torch.tensor(eigenvecs[:, :num_neurons].copy()).to(device)
    
    with torch.no_grad():
        # Set parameters
        new_fc1.weight[:num_out_1, :] = fc1.weight
        new_fc1.bias[:num_out_1] = fc1.bias
        new_fc1.weight[num_out_1:, :] = torch.zeros(num_neurons, num_in_1)
        new_fc1.bias[num_out_1:] = torch.zeros(num_neurons)
        # Set gradients
        #print("W_l+1^new transpose :", new_weights_next_layer.t().shape)
        #print("grow_matrix :", grow_matrix.shape)
        #print("new_grad :", fc1.weight.grad.shape)
        # new_fc1.weight.grad =  # torch.cat((fc1.weight.grad, add_weight_grads_1), dim=0)
        
        # add_weight_grads_1 = torch.mm(new_weights_next_layer.t(), grow_matrix).to(device)
        add_weight_grads_1 = torch.zeros(num_neurons, num_in_1).to(device)
        new_fc1.weight.grad = torch.cat((fc1.weight.grad,add_weight_grads_1), dim=0)
        
        add_bias_grad_1 = torch.zeros(num_neurons).to(device)
        new_fc1.bias.grad = torch.cat((fc1.bias.grad, add_bias_grad_1), dim=0)
        
    # Add weights to the following layer (fc2) so it matches fc1 size 
    num_in_2, num_out_2 = fc2.in_features, fc2.out_features
    new_fc2 = nn.Linear(in_features=num_in_2 + num_neurons, out_features=num_out_2).to(device)
    with torch.no_grad():
        new_fc2.weight[:, :num_in_2] = fc2.weight
        #print("new_weights :", new_fc2.weight[:, num_in_2:].shape)
        #print("potential :", new_weights_next_layer.shape)
        new_weights_next_layer = torch.zeros(num_out_2,num_neurons)
        new_fc2.weight[:, num_in_2:] = new_weights_next_layer # torch.zeros(num_out_2,num_neurons)
        # Set gradients
        add_weight_grads_2 = torch.zeros(num_out_2, num_neurons).to(device)
        new_fc2.weight.grad = torch.cat((fc2.weight.grad, add_weight_grads_2), dim=1)
    
    return new_fc1, new_fc2.to(device)

In [14]:
def register_hooks(model, pre_layer, post_layer):
    activation = []
    grad = []

    def forward_hook(module, x, y):
        activation.append(y)
    def backward_hook(module, grad_input, grad_output) :
        grad.append(grad_output)
    
    forward_hook_handle = None
    if pre_layer :
        forward_hook_handle = pre_layer.register_forward_hook(forward_hook)
        
    backward_hook_handle = post_layer.register_full_backward_hook(backward_hook)

    return activation, grad, forward_hook_handle, backward_hook_handle

In [72]:
def train (model, growth_schedule, loss, loss_name, optimizer, train_loader, val_loader, 
           num_epochs, batch_size, device, c = 1, print_num_params = 0) :
    train_loss_hist, val_loss_hist = [], []
    train_acc_hist, val_acc_hist = [], []
    
    # Epoch training loop
    for epoch in tqdm(range(num_epochs)):
        
        train_batch = iter(train_loader)
        loss_epoch = 0
        batch_index = 0
        
        # Batch training loop
        for i, (data, targets) in enumerate(train_batch):
            # Get data from the batch
            data = data.to(device)
            targets = targets.to(device)
            
            
            # Initialize the matrix on which we will perform GradMax
            if (growth_schedule is not None) and (batch_index%50 == 0) :
            
                
                # Remove previous hooks in case they are some
                for h in hooks_list :
                    h.remove()
                
                # Print the number of parameters
                if print_num_params :
                    count_all_parameters(model)
                
                layer_name, num_neurons = next(growth_schedule)
                if layer_name == "fc1" :
                    matrix_to_SVD = torch.zeros(model.fc2.out_features, model.fc1.in_features).to(device)
                elif layer_name == "fc2" :
                    matrix_to_SVD = torch.zeros(model.fc3.out_features, model.fc2.in_features).to(device)
                    #print("matrix_to_SVD :", matrix_to_SVD.shape)

            # forward pass
            model.train()
            
            # Get the gradient of the layer after the one we grow, for GradMax computation
            if (growth_schedule is not None) : #and (batch_index%50 == 0) :
                if layer_name == "fc1" :
                    
                    h = data.view(batch_size, -1)
                    # register hooks
                    _, grad, _, backward_hook_handle = register_hooks(model, None, model.fc2)
                    hooks_list.append(backward_hook_handle)
                    
                    # Forward path
                    y, loss_val = compute_loss(model, data, targets, loss, loss_name, batch_size)
                    
                    # Gradient calculation + weight update
                    #output_grad = torch.zeros(torch.Size([])).requires_grad_(True).to(device)
                    optimizer.zero_grad()
                    loss_val.backward() # Pass output_grad
                    optimizer.step()
                    
                    # Compute the matrix to which we will apply SVD
                    if batch_index%50 == 0 :
                        #print("output_grad : ", output_grad)
                        #print("grad shape :", grad[0][0].t()[0])
                        #print("h shape :", h.shape)
                        #print("h zeros :", (h==0).sum())
                        #print("matrix :", matrix_to_SVD)
                        pass
                    #matrix_to_SVD += torch.mm(grad[0][0].t(),h) / batch_size
                    
                elif layer_name == "fc2" :
                    # register hooks
                    h, grad, forward_hook_handle, backward_hook_handle = register_hooks(model, model.fc1, model.fc3)
                    hooks_list.append(forward_hook_handle)
                    hooks_list.append(backward_hook_handle)
                    
                    # Forward path
                    y, loss_val = compute_loss(model, data, targets, loss, loss_name, batch_size)
                    
                    # Gradient calculation + weight update
                    #output_grad = torch.ones(torch.Size([])).requires_grad_(True)
                    optimizer.zero_grad()
                    loss_val.backward() # Pass output_grad
                    optimizer.step()
                    
                    # Compute the matrix to which we will apply SVD
                    if batch_index%50 == 0 :
                        #print("output_grad : ", output_grad)
                        #print("grad shape :", grad[0][0].t()[0])
                        #print("h shape :", (h[0].shape))
                        #print("h zeros :", (h[0]==0).sum())
                        #print("h :", h[0][0])
                        #print("matrix :", matrix_to_SVD)
                        #print("w_fc1 :", model.fc2.weight[:,0])
                        pass
                    #matrix_to_SVD += torch.mm(grad[0][0].t(),h[0]) / batch_size
            
            else :
                # Forward path
                y, loss_val = compute_loss(model, data, targets, loss, loss_name, batch_size)
                
                # Gradient calculation + weight update
                optimizer.zero_grad()
                loss_val.backward()
                optimizer.step()
            
            loss_epoch += loss_val.item()
            
            
            # Add neurons
            if (growth_schedule is not None) and (batch_index%50 == 0) :
                #with torch.no_grad():
                    # Solve optimization problem (11)
                    #print("Shape of matrix_to_SVD :", matrix_to_SVD.shape)
                    # matrix_to_SVD = matrix_to_SVD.t()

                    #if layer_name == "fc1" :
                    #    model.fc1, model.fc2 = add_neurons(model.fc1, model.fc2, num_neurons, matrix_to_SVD, c, device)
                    #elif layer_name == "fc2" :
                    #    model.fc2, model.fc3 = add_neurons(model.fc2, model.fc3, num_neurons, matrix_to_SVD, c, device)
                    
                    fc1_weight_grad = model.fc2.weight.grad
                    fc1_bias_grad = model.fc2.bias.grad
                    fc2_weight_grad = model.fc3.weight.grad
                    model.add_neurons(fc1_weight_grad, fc1_bias_grad, fc2_weight_grad, num_neurons, device)
                    optimizer = torch.optim.Adam(model.parameters(), lr=5e-3, maximize=False)
            
            # Test set
            if batch_index%50 == 49:
                with torch.no_grad():
                    # Train data
                    # Store loss and acc histories for future plotting
                    train_loss_hist.append(loss_epoch/len(train_loader))
                    train_acc_hist.append(get_batch_accuracy(model, data, targets, batch_size))
                    
                    # Val data
                    val_loss_hist, val_acc_hist = compute_val (model,
                                                               loss,
                                                               loss_name,
                                                               val_loader, 
                                                               val_loss_hist, 
                                                               val_acc_hist,
                                                               epoch,
                                                               batch_size,
                                                               device,
                                                               print_shit=False)
                    
                    
            batch_index += 1

            if batch_index%50 == 0:
                print(f'{batch_index} batches used in epoch {epoch}')

        
    output = [train_loss_hist, train_acc_hist, val_loss_hist, val_acc_hist]
    return output

## Train the target model for comparisons

In [73]:
target_model = ANN(num_inputs, num_hidden_target, num_outputs).to(device)

In [74]:
loss = nn.MSELoss() # nn.CrossEntropyLoss()
loss_name = "MSE" # "CE"
optimizer = torch.optim.Adam(target_model.parameters(), lr=5e-3, maximize=False)
num_epochs = 3

In [75]:
growth_schedule = None
hooks_list= []

In [76]:
output = train(target_model, growth_schedule, loss, loss_name, optimizer, train_loader, val_loader, num_epochs, batch_size, device)

  0%|                                                                                            | 0/3 [00:01<?, ?it/s]

50 batches used in epoch 0





KeyboardInterrupt: 

In [None]:
output

## Train & Grow the model

In [77]:
root_model = ANN(num_inputs, num_hidden, num_outputs).to(device)

In [78]:
loss = nn.MSELoss()
#loss = nn.CrossEntropyLoss()
loss_name = "MSE" 
#loss_name = "CE"
optimizer = torch.optim.Adam(root_model.parameters(), lr=1e-3)
num_epochs = 3

In [79]:
num_neurons = 9
growth_schedule = iter([["fc1", num_neurons],["fc2", num_neurons],["fc1", num_neurons],
                        ["fc2", num_neurons],["fc1", num_neurons],["fc2", num_neurons],
                        ["fc1", num_neurons],["fc2", num_neurons],["fc1", num_neurons],
                        ["fc2", num_neurons],["fc1", num_neurons],["fc2", num_neurons],
                        ["fc1", num_neurons],["fc2", num_neurons],["fc1", num_neurons],
                        ["fc2", num_neurons],["fc1", num_neurons],["fc2", num_neurons],
                        ["fc1", num_neurons],["fc2", num_neurons],["fc1", num_neurons]])
c = 1

In [80]:
hooks_list= []

In [81]:
output = train(root_model, growth_schedule, loss, loss_name, optimizer, train_loader, val_loader, num_epochs, batch_size, device, c)

  0%|                                                                                            | 0/3 [00:00<?, ?it/s]

tensor(100, device='cuda:0')
50 batches used in epoch 0
tensor(190, device='cuda:0')
100 batches used in epoch 0
tensor(280, device='cuda:0')
150 batches used in epoch 0
tensor(370, device='cuda:0')
200 batches used in epoch 0
tensor(460, device='cuda:0')
250 batches used in epoch 0
tensor(550, device='cuda:0')
300 batches used in epoch 0
tensor(640, device='cuda:0')


 33%|████████████████████████████                                                        | 1/3 [00:09<00:18,  9.43s/it]

tensor(730, device='cuda:0')
50 batches used in epoch 1
tensor(820, device='cuda:0')
100 batches used in epoch 1
tensor(910, device='cuda:0')
150 batches used in epoch 1
tensor(1000, device='cuda:0')
200 batches used in epoch 1
tensor(1090, device='cuda:0')
250 batches used in epoch 1
tensor(1180, device='cuda:0')
300 batches used in epoch 1
tensor(1270, device='cuda:0')


 67%|████████████████████████████████████████████████████████                            | 2/3 [00:18<00:09,  9.33s/it]

tensor(1360, device='cuda:0')
50 batches used in epoch 2
tensor(1450, device='cuda:0')
100 batches used in epoch 2
tensor(1540, device='cuda:0')
150 batches used in epoch 2
tensor(1630, device='cuda:0')
200 batches used in epoch 2
tensor(1720, device='cuda:0')
250 batches used in epoch 2
tensor(1810, device='cuda:0')
300 batches used in epoch 2
tensor(1900, device='cuda:0')


100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:27<00:00,  9.30s/it]


In [82]:
output

[[0.019659467563941713,
  0.03322672337384486,
  0.045729304845558434,
  0.0563097451813519,
  0.06491421621928854,
  0.07178669707940483,
  0.004854535808923041,
  0.008826362554029358,
  0.012372124343883337,
  0.015586675988033233,
  0.018640646287921545,
  0.021570300896334032,
  0.002770292862295741,
  0.005379031159597017,
  0.007945245327768712,
  0.010499694401670883,
  0.013105801198767817,
  0.015756988230661104],
 [11.72,
  21.09,
  47.66,
  64.06,
  69.53,
  83.59,
  82.03,
  86.72,
  83.59,
  89.06,
  88.28,
  89.84,
  89.06,
  89.06,
  90.62,
  91.41,
  91.41,
  90.62],
 [0.09009817987680435,
  0.08709118515253067,
  0.07664433866739273,
  0.06534033268690109,
  0.05255601555109024,
  0.04278464615345001,
  0.03154897689819336,
  0.028091518208384514,
  0.026284297928214073,
  0.02581351436674595,
  0.026900852099061012,
  0.023041391745209694,
  0.02156229503452778,
  0.022510869428515434,
  0.02308960072696209,
  0.024956492707133293,
  0.020847737789154053,
  0.0233158

In [27]:
model_path = path + f"\\SNN_LoRA/ICL5_state_dict_task_{1}.pth"