In [219]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import numpy as np
from sklearn.decomposition import PCA
import time
from mpl_toolkits.mplot3d import Axes3D
import scipy

## Generate the datasets
#### X has shape n_fd x n_sa x n_ni x n_da
n_fd: number of plasticity rules simulated in parallel for the central finite-difference gradient = 2 × (number of parameters in the plasticity rule) + 1  
n_da: number of datasets presented to the PCA network  
n_sa: number of samples per dataset  
n_ni: number of input neurons in the PCA network (= dimension of the ambient space of the datasets)

In [220]:
def generate_datasets_Gaussian(n_samples, n_datasets, n_finite_differences, D, dev):
    """Generate zero-mean Gaussian datasets with randomly rotated covariances.

    Dataset d consists of n_samples i.i.d. draws from N(0, Q_d^T D Q_d), where
    Q_d is the orthogonal factor of the QR decomposition of a fresh uniform
    random matrix.

    Args:
        n_samples: number of samples per dataset (n_sa).
        n_datasets: number of datasets (n_da).
        n_finite_differences: number of identical copies stacked along a new
            leading axis (n_fd), one per plasticity rule simulated.
        D: square diagonal matrix (numpy array) holding the covariance
            spectrum; its size is the ambient dimension n_ni.
        dev: torch device on which the returned tensors are allocated.

    Returns:
        (X, pc1): X is a float64 tensor of shape (n_fd, n_sa, n_ni, n_da);
        pc1 is a float64 tensor of shape (n_ni, n_da) whose column d is the
        unit-norm first principal direction of dataset d. Its sign is
        arbitrary (normalised so the largest-magnitude entry is positive).
    """
    D = np.asarray(D)
    n_dims = D.shape[0]

    samples = np.zeros((n_samples, n_dims, n_datasets))
    for dataset_num in range(n_datasets):
        # NOTE(review): QR of a *uniform* matrix does not give Haar-distributed
        # rotations; kept so the data distribution is unchanged.
        Q, _ = np.linalg.qr(np.random.rand(n_dims, n_dims))
        # Cov is quadratic in Q, so flipping the sign of Q (e.g. to force
        # det(Q) > 0) would not change it; no flip is needed.
        Cov = Q.T @ D @ Q
        # A single draw of size n_samples consumes the same standard normals
        # as n_samples successive draws, but factorises Cov only once.
        samples[:, :, dataset_num] = np.random.multivariate_normal(
            np.zeros(n_dims), Cov, size=n_samples)

    # First principal direction = top right-singular vector of centred data.
    pc1 = np.zeros((n_dims, n_datasets))
    for d_num in range(n_datasets):
        centred = samples[:, :, d_num] - samples[:, :, d_num].mean(axis=0)
        _, _, vt = np.linalg.svd(centred, full_matrices=False)
        direction = vt[0]
        if direction[np.argmax(np.abs(direction))] < 0:
            direction = -direction
        pc1[:, d_num] = direction

    X = torch.tensor(samples, dtype=torch.double, device=dev)
    return (X.repeat(n_finite_differences, 1, 1, 1),
            torch.tensor(pc1, dtype=torch.double, device=dev))

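## Polynomial expansion of the plasticity rule
The rule is quadratic in each of the presynaptic activity $x$, the postsynaptic activity $y$ and the weight $w$: $\Delta w = \sum_{i,j,k \in \{0,1,2\}} A_{9i+3j+k}\, x^i y^j w^k$ (27 coefficients).
#### Parallelized to update n_fd plasticity rules at once (all finite-difference perturbations needed for the gradient are computed together)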

In [221]:
def G(X, Y, W, A, eta, dev):
    """Weight updates of n_fd polynomial plasticity rules, evaluated in parallel.

    Sums the 27 monomials x**i * y**j * w**k (i, j, k in {0, 1, 2}), each scaled
    by its coefficient column of A. The coefficient index is 9*i + 3*j + k:
    the presynaptic power varies slowest and the weight power fastest.

    Args:
        X: presynaptic activity, indexed "asnd" by the einsum below.
        Y: postsynaptic activity, indexed "asd".
        W: weights, indexed "and".
        A: coefficients; A[:, k] holds coefficient k of every rule (27 columns used).
        eta: learning rate.
        dev: torch device for the intermediate tensors.

    Returns:
        Tensor shaped like X holding eta * update / X.shape[1]. Dividing by
        the number of samples makes the sample-summed update independent of n_sa.
    """
    tensor_opts = {"dtype": torch.double, "device": dev}
    powers_x = (torch.ones(X.size(), **tensor_opts), X, torch.mul(X, X))
    powers_y = (torch.ones(Y.size(), **tensor_opts), Y, torch.mul(Y, Y))
    powers_w = (torch.ones(W.size(), **tensor_opts), W, torch.mul(W, W))
    monomials = [(px, py, pw) for px in powers_x for py in powers_y for pw in powers_w]

    update = torch.zeros(X.size(), **tensor_opts)
    for coeff_idx, (px, py, pw) in enumerate(monomials):
        update += torch.einsum("a,asnd,asd,and->asnd", A[:, coeff_idx], px, py, pw)
    n_sa = X.shape[1]
    return eta * update / n_sa

def G_1D(A, w, x, y):
    """Scalar form of the polynomial plasticity rule.

    Returns the sum over i, j, k in {0, 1, 2} of
    A[9*i + 3*j + k] * x**i * y**j * w**k, where x is the presynaptic
    activity, y the postsynaptic activity and w the weight (all scalars).
    """
    pre_terms = [1, x, x**2]
    post_terms = [1, y, y**2]
    weight_terms = [1, w, w**2]
    total = 0
    for i, pre in enumerate(pre_terms):
        for j, post in enumerate(post_terms):
            for k, weight in enumerate(weight_terms):
                total += A[9 * i + 3 * j + k] * pre * post * weight
    return total

def G_oja(w, x, y):
    """Oja's rule for one synapse: dw = x*y - w*y**2 (Hebbian term minus decay)."""
    hebbian = x * y
    decay = w * (y**2)
    return hebbian - decay

## PCA network
#### n_neurons input neurons (activity X) project onto 1 output neuron (activity Y) through weights W.
#### Parallelized over n_samples_per_dataset, n_datasets and n_fd (number of different plasticity rules run at the same time)
W: n_fd x n_ni x n_da,   Y: n_fd x n_sa x n_da

In [222]:
class pca_net():
    """Single-output linear network trained with n_fd plasticity rules in parallel.

    n_ni input neurons (activity X) project onto one output neuron (activity Y)
    through weights W. Every tensor carries a leading n_fd axis (one slice per
    plasticity rule) and a trailing n_da axis (one slice per dataset):
        X: (n_fd, n_sa, n_ni, n_da)   W: (n_fd, n_ni, n_da)   Y: (n_fd, n_sa, n_da)
    """

    # A rule whose final weights exceed NEAR_BLOW_UP_SCALE / n_sa in absolute
    # value is treated as blowing up in the last epoch.
    NEAR_BLOW_UP_SCALE = 100000

    def __init__(self, n_epochs, n_datasets, n_samples_per_dataset, n_fd, n_neurons, a_blow, b_blow, dev):
        self.n_sa = n_samples_per_dataset
        self.n_ni = n_neurons
        self.n_ep = n_epochs
        self.n_da = n_datasets
        self.n_fd = n_fd
        self.a_blow = a_blow
        self.b_blow = b_blow

        self.Y = torch.zeros((self.n_fd, self.n_sa, self.n_da), dtype=torch.double, device=dev)
        self.W = torch.zeros((self.n_fd, self.n_ni, self.n_da), dtype=torch.double, device=dev)

        self.blow_up = [self.n_ep for i in range(self.n_fd)]  # epoch at which each rule blew up (n_ep = no blow up)
        self.not_blown = [i for i in range(self.n_fd)]  # rules which have not blown up yet

    def forward(self, X):
        """Set Y[a, s, d] = sum_n X[a, s, n, d] * W[a, n, d]."""
        self.Y = torch.einsum("asnd,and->asd", X, self.W)

    def train(self, A_full, eta, X, pc1, dev):
        """Train all n_fd networks for n_ep epochs with the rules in A_full.

        Args:
            A_full: coefficients, one row per rule (passed to G).
            eta: learning rate of the plasticity rules.
            X: input data, shape (n_fd, n_sa, n_ni, n_da).
            pc1: principal directions; only used by the commented-out
                "initialise at the solution" option below.
            dev: torch device.

        Records in self.blow_up the epoch at which each rule produced NaN
        weights (or n_ep - 1 for rules whose final weights are huge).
        """
        ###### W initialisation, choose between 3 options ######
        # The same random W is shared by every rule, so score differences
        # between rules are not due to different initialisations.
        aux = 0.1*torch.randn([self.n_ni, self.n_da], dtype=torch.double, device=dev); self.W = aux.repeat(self.n_fd, 1, 1)
        # self.W = torch.load("Winit1.pt"); self.W = self.W[0, :, :].repeat(self.n_fd, 1, 1)
        # self.W = pc1.clone().detach().repeat(self.n_fd, 1, 1)  # initialise W at the solution
        #######################################################

        for epoch_num in range(self.n_ep):
            self.forward(X)
            self.W += torch.sum(G(X, self.Y, self.W, A_full, eta, dev), 1)
            if torch.isnan(self.W).any():  # cheap global check before the per-rule scan
                self._flag_nan_blow_ups(epoch_num)

        # Finite but huge weights at the end would distort the finite-difference gradient.
        self._flag_near_blow_ups()

    def _flag_nan_blow_ups(self, epoch_num):
        """Mark every still-alive rule whose weights contain a NaN as blown up at epoch_num."""
        newly_blown = []
        for fd_num in self.not_blown:
            if torch.isnan(self.W[fd_num]).any():
                print("epoch_num =" + str(epoch_num) + ";  blow_up on dim " + str(fd_num))
                self.blow_up[fd_num] = epoch_num
                newly_blown.append(fd_num)
        self.not_blown = [fd for fd in self.not_blown if fd not in newly_blown]

    def _flag_near_blow_ups(self):
        """Mark still-alive rules with |W| above the near-blow-up threshold as blown up in the last epoch."""
        threshold = self.NEAR_BLOW_UP_SCALE / self.n_sa
        newly_blown = []  # accumulated over all rules, not reset per rule
        for fd_num in self.not_blown:
            # absolute value: huge negative weights are just as unstable
            if torch.max(torch.abs(self.W[fd_num])).item() > threshold:
                print("epoch_num =" + str(self.n_ep) + ";  blow_up (to come) on dim " + str(fd_num))
                self.blow_up[fd_num] = self.n_ep - 1
                newly_blown.append(fd_num)
        self.not_blown = [fd for fd in self.not_blown if fd not in newly_blown]

    def score(self, pc1):
        """Score each rule.

        Returns:
            (has_blown_up, err_mean_pca). has_blown_up is True if any rule blew up.
            err_mean_pca[fd] is either
            - the blow-up penalty a_blow * (n_ep - blow_up[fd]) / n_ep + b_blow, which
              is larger the earlier the rule blew up, or
            - the sign-invariant distance min(|pc1 - W|, |pc1 + W|), summed over
              datasets and divided by n_da * (n_ni / 3) ** (1/3).
        """
        has_blown_up = False
        err_mean_pca = [0]*self.n_fd
        norm_factor = self.n_da*(self.n_ni**(1/3))/(3**(1/3))
        for fd_num in range(self.n_fd):
            if self.blow_up[fd_num] < self.n_ep:  # this perturbation of A blew up
                has_blown_up = True
                err_mean_pca[fd_num] = self.a_blow*(self.n_ep - self.blow_up[fd_num])/(self.n_ep) + self.b_blow
            else:
                for d_num in range(self.n_da):
                    err_mean_pca[fd_num] += (min(torch.norm(pc1[:, d_num] - self.W[fd_num, :, d_num]).item(), torch.norm(pc1[:, d_num] + self.W[fd_num, :, d_num]).item()))/norm_factor
        return(has_blown_up, err_mean_pca)

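## Gradient descent on the plasticity rule
The gradient of the loss with respect to the rule coefficients A is estimated with central finite differences. A is updated with a modified Adam optimizer. There is no bias correction. After a blow-up (except shortly after the previous reset), the moment estimates are reset and A is rolled back `n_it_back` meta-iterations.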

In [223]:
class learningG_finite_diff():
    """Meta-learn the coefficients A of a polynomial plasticity rule.

    The loss of a rule is the score of a pca_net trained with it, plus an L1
    penalty on the coefficients. The gradient is estimated by central finite
    differences. Rule 0 is A itself; rules 2*i+1 and 2*i+2 are A with +h and
    -h added to coefficient i. All n_fd = 2*len(A) + 1 rules are simulated in
    parallel. A is updated with an Adam-like step without bias correction.
    After a blow-up, the moment estimates are reset and A is rolled back
    n_it_back meta-iterations.
    """

    # A blow-up fewer than this many meta-iterations after the previous reset
    # (or after meta-iteration 0) does not trigger another reset.
    RESET_COOLDOWN = 11

    def __init__(self, D, n_samples_per_dataset, n_epochs, n_datasets, A, eta, h, meta_eta, beta1, beta2, epsilon, reg, a_blow, b_blow, n_it_back, dev=None):
        """Generate the training datasets and initialise the optimiser state.

        Args:
            D: diagonal covariance spectrum of the synthetic datasets; its size
                is the number of input neurons n_ni.
            n_samples_per_dataset, n_epochs, n_datasets: pca_net settings.
            A: initial rule coefficients (1-D tensor).
            eta: learning rate of the plasticity rule inside pca_net.
            h: finite-difference step.
            meta_eta, beta1, beta2, epsilon: Adam-like step size and parameters.
            reg: weight of the L1 penalty on the coefficients.
            a_blow, b_blow: blow-up penalty parameters passed to pca_net.
            n_it_back: meta-iterations to roll back after a blow-up.
            dev: torch device. Defaults to CUDA when available, else CPU,
                matching how the parameter cell defines `dev`.
        """
        if dev is None:
            dev = self._default_device()
        self.dev = dev
        self.n_sa = n_samples_per_dataset
        self.n_ni = len(D)
        self.n_ep = n_epochs
        self.n_da = n_datasets
        self.n_fd = 2*len(A) + 1

        ###### Generation of the training datasets ######
        self.X, self.pc1 = generate_datasets_Gaussian(self.n_sa, self.n_da, self.n_fd, D, dev)
        # self.X = torch.load("X1.pt"); self.pc1 = torch.load("pcX1.pt"); self.X = self.X[0, :, :, :].repeat(self.n_fd, 1, 1, 1)
        ##################################################

        self.A = A.clone().detach()
        self.eta = eta
        self.m_eta = meta_eta
        self.h = h
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.current_meta_it = 0
        self.m = torch.zeros(len(self.A), dtype=torch.double, device=dev)
        self.v = torch.zeros(len(self.A), dtype=torch.double, device=dev)
        self.reg = reg
        self.a_blow = a_blow
        self.b_blow = b_blow
        self.last_blow_up = 0
        self.n_it_back = n_it_back

        # Perturbations for the central finite differences: row 0 is zero,
        # rows 2*i+1 / 2*i+2 add +h / -h to coefficient i.
        self.A_h = torch.zeros(self.n_fd, len(A), dtype=torch.double, device=dev)
        for dim_num in range(len(A)):
            self.A_h[2*dim_num + 1, dim_num] = h
            self.A_h[2*dim_num + 2, dim_num] = (-h)

        # History for plotting. A_hist is flat: block t holds A after t meta-iterations.
        self.A_hist = self.A.clone().detach()
        self.loss_hist = []
        self.angle_hist = []

    @staticmethod
    def _default_device():
        """CUDA device if available, otherwise CPU."""
        return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    def get_grad(self, dev):
        """Estimate the gradient of the loss with respect to A by central finite differences.

        Returns:
            (has_blown_up, loss of the unperturbed rule without the L1 term, gradient tensor).
        """
        net = pca_net(self.n_ep, self.n_da, self.n_sa, self.n_fd, self.n_ni, self.a_blow, self.b_blow, dev)
        A_full = self.A + self.A_h
        has_blown_up, loss_no_L1, losses = self.loss(A_full, net, dev)
        # losses[2i+1] / losses[2i+2] are the +h / -h perturbations of coefficient i
        grad = (losses[1::2] - losses[2::2])/(2*self.h)
        g = torch.tensor(grad, dtype=torch.double, device=dev)
        return(has_blown_up, loss_no_L1[0], g)

    def loss(self, A_full, net, dev=None):
        """Train `net` with every rule in A_full and score it.

        Returns:
            (has_blown_up, per-rule loss without L1,
             numpy array of per-rule loss + reg * ||A_full[i]||_1).
        """
        if dev is None:
            dev = getattr(self, "dev", None)
        if dev is None:
            dev = self._default_device()
        net.train(A_full, self.eta, self.X, self.pc1, dev)
        l1 = self.reg*np.array([torch.norm(A_full[i, :], p=1).item() for i in range(self.n_fd)])
        has_blown_up, loss_no_l1 = net.score(self.pc1)
        return(has_blown_up, loss_no_l1, np.asarray(loss_no_l1) + l1)

    def train(self, A_oja, dev, n_meta_it):
        """Run n_meta_it meta-iterations of the Adam-like update on A.

        A_oja is only the reference direction for the angle history. Returns
        early (with an empty tuple) if a NaN reaches A.
        """
        initial_meta_it = self.current_meta_it
        n_coeffs = len(self.A)
        y = A_oja.cpu().numpy()  # reference direction for the angle history

        while self.current_meta_it < (initial_meta_it + n_meta_it):
            has_blown_up, loss, g = self.get_grad(dev)

            if (not has_blown_up) or (self.current_meta_it - self.last_blow_up < self.RESET_COOLDOWN):
                # no reset: no blow-up, or a blow-up too soon after the previous reset
                self.m = self.beta1*self.m + (1-self.beta1)*g  # momentum
                self.v = self.beta2*self.v + (1-self.beta2)*torch.mul(g, g)  # 2nd moment estimate
            else:
                # Roll back n_it_back meta-iterations (never past the initial A)
                # and restart the moment estimates from a fresh gradient.
                n_back = min(self.n_it_back, self.current_meta_it)
                self.current_meta_it -= n_back
                self.last_blow_up = self.current_meta_it
                self.A_hist = self.A_hist[0: (self.A_hist.size()[0] - n_coeffs*n_back)]
                self.A = self.A_hist[-n_coeffs:].clone().detach()
                self.loss_hist = self.loss_hist[0: (len(self.loss_hist) - n_back)]
                self.angle_hist = self.angle_hist[0: (len(self.angle_hist) - n_back)]
                print("momentum reset at meta_it " + str(self.current_meta_it))

                has_blown_up, loss, g = self.get_grad(dev)
                self.m = (1-self.beta1)*g
                self.v = (1-self.beta2)*torch.mul(g, g)

            ###### bias correction is currently disabled ######
            # m_hat = self.m/(1-(self.beta1**(self.current_meta_it+1)))
            # v_hat = self.v/(1-(self.beta2**(self.current_meta_it+1)))
            m_hat = self.m
            v_hat = self.v

            self.A -= self.m_eta*torch.div(m_hat, torch.sqrt(v_hat) + self.epsilon)  # update A

            if torch.isnan(self.A).any():
                print("a nan made it inside A")
                return()
            self.A_hist = torch.cat((self.A_hist, self.A), 0)
            self.loss_hist.append(loss)
            x = self.A.cpu().numpy()
            cos_angle = np.dot(x, y)/(np.linalg.norm(x)*np.linalg.norm(y))
            # clip guards arccos against rounding just outside [-1, 1]
            self.angle_hist.append(np.arccos(np.clip(cos_angle, -1.0, 1.0))*180/(np.pi))
            if (self.current_meta_it % 10 == 0):
                print("iteration " + str(self.current_meta_it+1) + "/" + str(initial_meta_it + n_meta_it))
                print("current_loss (without L1 term): " + str(loss))
            self.current_meta_it += 1

    def plot(self):
        """Plot the coefficient trajectories, the loss history and the angle to A_oja."""
        n_coeffs = len(self.A)
        # row t: A after t meta-iterations (row 0: initial A, last row: current A)
        coeff_hist = self.A_hist.reshape(-1, n_coeffs).cpu().numpy()

        plt.figure(1)
        for dim_num in range(n_coeffs):
            plt.plot(coeff_hist[:, dim_num], linewidth = 8)
        plt.title("evolution of coefficients of A through training")
        plt.show()

        plt.figure(2)
        plt.plot(self.loss_hist, linewidth = 4)
        plt.title("loss across iterations")
        plt.show()

        plt.figure(3)
        plt.plot(self.angle_hist, linewidth = 4)
        plt.title("angle A - A_oja across iterations")
        plt.show()

## Simulation

#### 1/ parameters

In [224]:
dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

###### Reproducibility: seed the RNGs used below (np.random for the synthetic
###### datasets, torch for A_rand and the pca_net weight initialisation)
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

###### Synthetic dataset parameters, D has to be odd dimension
D3 = np.diag([2,1,0]); D3bis = np.diag([1,0.5,0])
D5 = np.diag([2,1,1,0,0]); D5bis = np.diag([1,0.7,0.2,0.1,0])
D11 = np.diag([5,4,3,2,1,0,0,0,0,0,0]); D11bis = np.diag([1,0.8,0.5,0.3,0.1,0,0,0,0,0,0])
D39 = np.diag([10,9,9,9,8,8,8,7,7,7,5,5,5,5,5,5,5,5,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0])
D39bis = np.diag([1,0.9,0.9,0.9,0.8,0.8,0.8,0.7,0.7,0.7,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0,0,0,0,0,0,0,0,0,0,0])
# Spectrum used by the simulation cell. `D` was never assigned in this notebook
# (it came from hidden kernel state); D11 is an assumed default — TODO confirm.
D = D11

###### PCA_net parameters
n_samples_per_dataset = 100; n_epochs = 200; eta = 1/20

###### Meta optimisation parameters
n_meta_it = 10; n_datasets = 100; h = 1/200; meta_eta = 1/200; beta1 = 0.9; beta2 = 0.999
epsilon = 1e-6; reg = 0.1; a_blow = 1; b_blow = 2; n_it_back = 2

###### Initialisation (index 9*i + 3*j + k <-> coefficient of x^i y^j w^k)
A_oja = torch.tensor([0,0,0,0,0,0,0,(-1),0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0], dtype=torch.double, device=dev)
A_0 = torch.tensor([0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0], dtype=torch.double, device=dev)
A_close_oja    = torch.tensor([-0.0810, -0.0161, -0.1892, -0.1068,  0.0860,  0.0048,  0.0403, -1.1123,
         0.0957, -0.0231, -0.0791,  0.1590,  1.0008,  0.1013, -0.2789,  0.0728,
         0.0148,  0.2366, -0.1421,  0.0984, -0.0901,  0.0740,  0.1227,  0.0633,
        -0.0386, -0.0694, -0.0459], dtype=torch.double, device=dev)
A_rand_0 = torch.tensor([ 0.0867,  0.0778, -0.0047, -0.1911, -0.1838, -0.0798, -0.0325,  0.0380,
         0.0030,  0.0540, -0.0596,  0.0153, -0.1660, -0.0241,  0.1270, -0.0591,
        -0.0492, -0.0512, -0.0270, -0.0016, -0.0454,  0.0275,  0.1039, -0.0373,
        -0.0933, -0.1817,  0.0762], dtype=torch.double, device=dev)
A_unstable = torch.tensor([ 0.1945, -0.4187,  1.3444, -0.0039,  0.2078, -0.9590,  0.7205,  1.3984,
        -0.4893, -0.9633, -0.7834,  0.0169, -1.0988, -1.6190, -1.1393, -0.8860,
         0.9665, -0.9679, -0.1661, -0.4697, -0.3577, -0.6227, -0.1183,  0.4971,
        -0.2486, -1.8236,  0.2660], dtype=torch.double, device=dev)
A_adam1 = torch.tensor([ 0.0102, -0.0200, -0.0091, -0.0054, -0.0050, -0.0050,  0.1968, -0.3316,
        -0.0042, -0.0057, -0.0037, -0.0048,  0.2370, -0.1286, -0.0873,  0.0192,
        -0.0199,  0.0208, -0.0545, -0.2052, -0.0135, -0.0332,  0.0178, -0.0219,
        -0.1949, -0.8068, -0.3349], dtype=torch.double, device=dev)
A_adam2 = torch.tensor([-5.5102e-02, -1.4344e-01,  9.0817e-03, -1.0121e-02, -2.4525e-03,
        -5.3187e-03,  1.3737e-01, -2.7457e-01,  2.6577e-02, -6.3599e-03,
        -1.6456e-04, -4.2544e-03,  5.1275e-01, -1.9923e-02, -2.3343e-02,
        -7.9909e-03, -4.7050e-03, -3.3497e-03, -9.8408e-02,  3.6230e-02,
        -2.2974e-03, -1.1924e-02, -7.9329e-03, -7.6932e-03, -6.9461e-02,
        -2.5508e-01, -1.5929e-01], dtype=torch.double, device=dev) #A_adam2 is iteration 120 from plast_rule17
A_rand = 0.01*torch.randn(A_oja.size(),dtype=torch.double, device=dev)
# Atest / Atest2 were hard-coded to device='cuda:0', which fails on CPU-only machines.
Atest = torch.tensor([ 0.0206, -0.0080,  0.0019, -0.0092,  0.0112,  0.0097,  0.0056, -0.4870,
        -0.0165, -0.0193, -0.0083,  0.0021,  0.6872,  0.0121,  0.0276, -0.0076,
        -0.0091, -0.0112, -0.0181,  0.0651,  0.0073, -0.0084, -0.0043, -0.0025,
        -0.0166, -0.1224, -0.0085], device=dev, dtype=torch.float64)
Atest2 = torch.tensor([-4.6946e-02, -2.4606e-02,  7.6413e-03,  1.3415e-02, -3.9182e-04,
         5.4507e-05,  2.7690e-02, -4.9433e-01,  9.2513e-02,  1.8933e-02,
         7.0540e-03,  8.6338e-04,  5.9564e-01,  8.2658e-03,  3.2073e-02,
         1.3348e-02, -2.8721e-03,  4.3265e-04, -3.3670e-02,  6.5483e-02,
         6.8184e-03,  2.5619e-03, -4.4377e-03, -1.2937e-03,  1.6686e-02,
        -7.9009e-02, -8.7798e-02], device=dev, dtype=torch.float64)

#### 2/ simulation

In [225]:
start = time.time()

###### START A NEW SIMULATION #######################
plast_rule = learningG_finite_diff(
    D, n_samples_per_dataset, n_epochs, n_datasets, A_rand_0, eta, h, meta_eta,
    beta1, beta2, epsilon, reg, a_blow, b_blow, n_it_back,
)
setup_time = np.round(time.time() - start, 2)
print("datasets generated (" + str(setup_time) + "s)")
#####################################################

###### CONTINUE WORKING ON A PRE_TRAINED MODEL ######
# NOTE: torch.load unpickles arbitrary objects — only load files you created.
# plast_rule = torch.load("rule34.pt")
# plast_rule.plot()
#####################################################

plast_rule.train(A_oja, dev, n_meta_it)
plast_rule.plot()

total_time = np.round(time.time() - start, 2)
print("Total time in s: " + str(total_time) + "s")

Total time in s: 0.0s
