In [1]:
import torch
import torch.nn as nn
from pyhessian import hessian
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from torchsummary import summary
import scipy as sp

epochs = 200
n_tasks = 1
L = 0
gamma0 = 1
widths = [128]
device = 'mps'
gen = torch.Generator(device=device)
gen.manual_seed(123)
batch = 20

class MLP(nn.Module):
            def __init__(self, w, L, param, gam):
                super(MLP, self).__init__()
                self.w = w
                if param =='ntk':
                    self.gamma = gam
                    self.in_scale = 784**0.5
                    self.out_scale = self.w**0.5*self.gamma
                elif param == 'mup': 
                     self.gamma = gam*self.w**0.5
                     self.in_scale = 784**0.5
                     self.out_scale = self.w**0.5*self.gamma
                elif param == 'sp':
                     self.gamma = 1
                     self.in_scale = 1
                     self.out_scale = 1

                self.fc1 = nn.Linear(784, self.w, bias=False)
                self.fc2 = nn.Linear(self.w, 10, bias=False)
                self.relu = nn.ReLU()
                self.L = L

            def forward(self, x):
                h1 = self.fc1(x)/self.in_scale
                h1act = self.relu(h1)
                h2 = self.fc2(h1act)/self.out_scale

                return h2
            
@torch.no_grad()
def init_weights(m):
    if type(m) == nn.Linear:
        m.weight.normal_()

def permut_row(x, perm):
            return x[perm]

        # ---------------------- START DATA -------------------------
data = pd.read_csv('~/data/MNIST/mnist_train.csv')
test = pd.read_csv('~/data/MNIST/mnist_test.csv')
#data = data[data['label'].isin([0, 1])]
#test = test[test['label'].isin([0, 1])]
X = torch.tensor(data.drop('label', axis = 1).to_numpy(), device=device)/255
X_test = torch.tensor(test.drop('label', axis = 1).to_numpy(), device=device)/255
X = X[:20]

Y_temp = torch.tensor(data['label'].to_numpy(), device=device)
Y = torch.eye(10, device=device)[Y_temp]
Y = Y[:20]

Y_temp = torch.tensor(test['label'].to_numpy(), device=device)
Y_test = torch.eye(10, device=device)[Y_temp]

tasks = [X]
tasks_test = [X_test]

for _ in range(n_tasks):
        perm = np.random.permutation(X.shape[1])
        tasks.append( torch.tensor(np.apply_along_axis(permut_row, axis = 1, arr=X.cpu(), perm=perm)).to(device) )
        tasks_test.append(torch.tensor(np.apply_along_axis(permut_row, axis = 1, arr=X_test.cpu(), perm=perm)).to(device))


def top_eigen(model, loss, X, Y, prt=False):

            hess_comp = hessian(model, loss, (X,Y) )
            top_eigenvalues, top_eigenvector = hess_comp.eigenvalues()
                
            return top_eigenvalues[-1] , top_eigenvector

def overlap(model, inputs, targets):
    
    gradients = torch.cat([param.grad.view(-1) for param in model.parameters()])
    params = torch.cat([param.data.view(-1) for param in model.parameters()])

    norm = torch.norm(params)

    def loss_fn(params):
        idx = 0
        layers = []
        for param in model.parameters():
            param_numel = param.numel()
            layers.append(params[idx:idx + param_numel].view_as(param))
            idx += param_numel
        relu = nn.ReLU()
        outputs = relu(inputs@layers[0].T) @ layers[1].T
        return MSE(outputs, targets)

    hvp = torch.autograd.functional.hvp(loss_fn, params, gradients)[1]
    return (torch.dot(hvp,gradients)/(torch.norm(hvp) * torch.norm(gradients))).item(), norm.item()


In [2]:
save_out = False

for regime in ['ntk']:
    for N in widths:

        loss_hist = []
        lam = []
        acc = []
        all = []
        norm = []
        res1 = []
        res2 = []
        
        mlp = MLP(N,L,regime, gamma0)

        if regime == 'ntk' or regime == 'mup':
            mlp = mlp.apply(init_weights)
            
        summary(mlp, (1,784))
        mlp = mlp.to(device)
        
        optimizer = torch.optim.SGD(mlp.parameters(), lr= mlp.gamma**2)
        eos = 2/mlp.gamma**2
     
        MSE = nn.MSELoss()

        for t,Xt in enumerate(tasks):        
                for epoch in range(epochs):

                        running_loss = 0.0
                        for i in range(len(Xt)//batch):

                            # Batch of training 
                            ix = torch.randint(0, len(X), (batch,), generator=gen, device=device)

                            ixc = torch.randint(0, len(X), (1024,), generator=gen, device=device)

                            lt = []
                            for s in range(t+1):
                                sharp, eigen = top_eigen(mlp, MSE, tasks[s][ixc], Y[ixc])
                                lt.append(sharp)
                            lam.append(lt)    

                            optimizer.zero_grad()

                            out = mlp(Xt[ix])
                            loss = MSE(out, Y[ix])

                            #res1.append( list((torch.sum(mlp(tasks[0][ix]) - Y[ix], dim=1)).detach().cpu()) )
                            #res2.append( list((torch.sum(mlp(tasks[1][ix]) - Y[ix], dim=1)).detach().cpu()) )
                            
                            loss.backward()
                            running_loss += loss.item()

                            over, no = overlap(mlp, X[ix], Y[ix])

                            all.append(over)
                            norm.append(no)
                            
                            optimizer.step()
                            loss_hist.append(loss.item())

                            #print(f'task {t} : (epoch: {epoch}), sample: {batch*(i+1)}, ---> train loss = {loss.item():.4f}')

                print(f'Finished Training task{t}, train loss: {running_loss/batch}')
                
                acct = []
                for s in range(t+1):
                    acct.append( (torch.sum(torch.argmax(mlp(tasks_test[s]), dim=1) == torch.argmax(Y_test, dim=1))/len(Y_test)).item() )  
                acc.append(acct) 
                
        if save_out:        
            with open(f'/Users/alessandrobreccia/Desktop/THESIS/data/lamda{N}_{regime}.txt', 'w') as file:
            
                for lst in lam:
                    file.write(' '.join(map(str, lst)) + ' ')
        
            with open(f'/Users/alessandrobreccia/Desktop/THESIS/data/acc{N}_{regime}.txt', 'w') as file:

                for lst in acc:
                    file.write(' '.join(map(str, lst)) + ' ')
        
            with open(f'/Users/alessandrobreccia/Desktop/THESIS/data/overlap{N}_{regime}.txt', 'w') as file:

                for a in all:
                    file.write(str(a) + ' ')
        
            with open(f'/Users/alessandrobreccia/Desktop/THESIS/data/norm{N}_{regime}.txt', 'w') as file:

                for n in norm:
                    file.write(str(n) + ' ')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1               [-1, 1, 128]         100,352
              ReLU-2               [-1, 1, 128]               0
            Linear-3                [-1, 1, 10]           1,280
Total params: 101,632
Trainable params: 101,632
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.39
Estimated Total Size (MB): 0.39
----------------------------------------------------------------


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Finished Training task0, train loss: 0.0043152239173650745
Finished Training task1, train loss: 0.003935420885682106


In [None]:
def SHARP():
    T = epochs*len(X)//batch
    colors = ['blue', 'green', 'red', 'cyan', 'magenta', 'yellow', 'black', 'gray', 'orange', 'purple']
    plt.figure(figsize=(20,10))
    plt.title(f'Sharpness evolution with net of L={L}, lr={2/eos}, batch={batch}')
    N = 128
    for regime,ls in zip(['sp','ntk','mup'],['-',':']):

        lam_rec = []

        with open(f'/Users/alessandrobreccia/Desktop/THESIS/data/lamda{N}_{regime}.txt', 'r') as file:
            lines = file.readlines()

        for line in lines:
            lst = list(map(float, line.strip(',').split()))
            lam_rec.append(lst)
                
        a = np.array(lam_rec).T

        sh1 = a[0:T]
        sh2 = a[T:3*T]
        sh3 = a[3*T:6*T]
        sh4 = a[6*T:10*T]
    
        for i,sh in enumerate([sh1,sh2,sh3,sh4]):
            plt.axvline(i*T, color='r', linestyle='dotted')
            for j,row in enumerate(sh.reshape(T,i+1).T):
                plt.plot(np.arange(i*T,(i+1)*T,1), row, color=colors[j], linestyle=ls)

    plt.ylabel('Sharpness')
    plt.xlabel('Epochs')
    #plt.axhline(eos, color='black', linestyle='dotted')
    plt.grid()
    plt.show()

SHARP()

In [None]:
def ACC():
    T = epochs*len(X)//batch
    colors = ['blue', 'green', 'red', 'cyan', 'magenta', 'yellow', 'black', 'gray', 'orange', 'purple']
    N = 128
    for regime,ls in zip(['sp','ntk','mup'],['o','*','+']):

        acc_rec = []

        with open(f'/Users/alessandrobreccia/Desktop/THESIS/data/acc{N}_{regime}.txt', 'r') as file:
            lines = file.readlines()

        for line in lines:
            lst = list(map(float, line.strip(',').split()))
            acc_rec.append(lst)
                
        a = np.array(acc_rec).T

        accs = [a[:1],a[1:3],a[3:6],a[6:10]]

        for i in range(n_tasks):
                plt.plot(range(i,n_tasks+1),[a[i] for a in accs[i:]], color=colors[i], marker=ls)
        plt.plot(n_tasks,accs[-1][-1], color=colors[-1], marker=ls )
        
    plt.ylabel('Accuracy')
    plt.xlabel('Tasks')
    plt.grid()
    plt.show()
ACC()

In [None]:
def OVERLAP():
    T = epochs*len(X)//batch
    colors = ['blue', 'green', 'red', 'cyan', 'magenta', 'yellow', 'black', 'gray', 'orange', 'purple']
    
    plt.figure(figsize=(20,10))
    plt.title(f'Overlap between Hessian of task 1 and gradient at each task')
    N = 128
    for regime,c in zip(['sp','ntk','mup'],[0,1,2]):

        over_rec = []

        with open(f'/Users/alessandrobreccia/Desktop/THESIS/data/overlap{N}_{regime}.txt', 'r') as file:
            lines = file.readlines()

        for line in lines:
            lst = list(map(float, line.strip(',').split()))
            over_rec.append(lst)
                
        a = np.array(over_rec).T

        sh1 = a[0:T]
        sh2 = a[T:2*T]
        sh3 = a[2*T:3*T]
        sh4 = a[3*T:4*T]

        for i,sh in enumerate([sh1,sh2,sh3,sh4]):
            plt.axvline(i*T, color='r', linestyle='dotted')
            for j,row in enumerate(sh.T):
                plt.plot(np.arange(i*T,(i+1)*T,1), row, color=colors[c], label= regime)

    plt.ylabel('Sharpness')
    plt.xlabel('Epochs')
    plt.legend()
    plt.grid()
    plt.show()

OVERLAP()

In [None]:
def NORM():
    T = epochs*len(X)//batch
    colors = ['blue', 'green', 'red', 'cyan', 'magenta', 'yellow', 'black', 'gray', 'orange', 'purple']
    
    plt.figure(figsize=(20,10))
    plt.title(f'Weigths norm of the models')
    N = 128
    for regime,c in zip(['sp','ntk','mup'],[0,1,2]):

        norm_rec = []

        with open(f'/Users/alessandrobreccia/Desktop/THESIS/data/norm{N}_{regime}.txt', 'r') as file:
            lines = file.readlines()

        for line in lines:
            lst = list(map(float, line.strip(',').split()))
            norm_rec.append(lst)
                
        a = np.array(norm_rec).T

        sh1 = a[0:T]
        sh2 = a[T:2*T]
        sh3 = a[2*T:3*T]
        sh4 = a[3*T:4*T]

        for i,sh in enumerate([sh1,sh2,sh3,sh4]):
            plt.axvline(i*T, color='r', linestyle='dotted')
            for j,row in enumerate(sh.T):
                plt.plot(np.arange(i*T,(i+1)*T,1), row, color=colors[c], label=regime)

    plt.ylabel('Sharpness')
    plt.xlabel('Epochs')
    plt.grid()
    plt.legend()
    plt.show()

NORM()

In [None]:
HLT

In [None]:
a = np.array(res1)
b = np.array(res2)

plt.plot(a)
plt.axvline(epochs)
plt.show()

plt.plot(b)
plt.axvline(epochs)
plt.show()

import jax.numpy as jnp
from jax import random
from neural_tangents import stax
import scipy as sc

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(128), stax.Relu(),
    stax.Dense(10),
)

key1, key2 = random.split(random.PRNGKey(1))

_, params = init_fn(key1, input_shape=X.shape)

ntk22 = kernel_fn(jnp.array(tasks[1].cpu()),jnp.array(tasks[1].cpu()), 'ntk')
ntk12 = kernel_fn(jnp.array(tasks[0].cpu()),jnp.array(tasks[1].cpu()), 'ntk')
inv_ntk22 = sc.linalg.inv(ntk22)

g = mlp.gamma

def delta1(t,delta10 , delta20):
    term_time = (np.eye(len(X)) - sc.linalg.expm(-np.array(ntk22)*t/g))
    kerns_term = ntk12 @ inv_ntk22
    d = np.array(delta10) - np.array(delta20) @ kerns_term @ term_time
    return d

p = []
for t in np.linspace(0,1,200):
    p.append(delta1(t,a[200],b[200]))

x = np.array(p)
plt.plot(np.mean(x, axis = 1))
plt.plot(np.mean(a[200:], axis= 1) )
plt.show()