In [None]:
import numpy as np
import scipy.linalg
import scipy.special
from scipy.stats import ortho_group
from scipy import optimize
import copy

import matplotlib.pyplot as plt
from matplotlib import colors

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import pandas as pd
from state_evolution.algorithms.state_evolution import StateEvolution # Standard SP iteration
from state_evolution.data_models.custom import Custom # Custom data model. You input the covariances
from state_evolution.experiments.learning_curve import CustomExperiment
from state_evolution.data_models.custom import CustomSpectra

In [None]:
def give_raw_data(n,d,F_t,theta,seed=None):
    """Sample Gaussian inputs and label them with the tanh teacher.

    Inputs are drawn as a ``(d, n)`` standard-normal matrix, passed through
    ``tanh(F_t @ inputs / sqrt(d))`` and projected onto ``theta``.

    Parameters
    ----------
    n : int
        Number of samples.
    d : int
        Input dimension.
    F_t : array
        Teacher first-layer weights; must be multipliable with a ``(d, n)``
        matrix.
    theta : array
        Teacher readout weights; must be multipliable with the hidden
        activations.
    seed : int, optional
        Seed of the private random stream. Defaults to ``n``, so the same
        sample count always reproduces the same data.

    Returns
    -------
    (inputs, labels)
        ``inputs`` has shape ``(n, d)`` (one sample per row); ``labels`` is
        ``theta @ hidden``.
    """
    # A private RandomState yields exactly the stream that
    # `np.random.seed(...)` + `np.random.randn` would produce, but leaves the
    # process-wide NumPy RNG untouched for the rest of the notebook.
    rng=np.random.RandomState(n if seed is None else seed)
    inputs=rng.randn(d,n)
    hidden=np.tanh(F_t@inputs/np.sqrt(d))
    labels=theta@hidden
    return inputs.T, labels

def give_dataset(n_test,n_train,d,F_t,theta):
    """Create full-batch train and test loaders from teacher-generated data.

    Both splits come from ``give_raw_data`` and are wrapped in
    ``CustomDataset``. Each loader uses a batch size equal to its split size,
    so iterating a loader yields a single (shuffled) batch.

    NOTE(review): ``give_raw_data`` derives its random seed from the sample
    count, so ``n_test == n_train`` produces identical train and test samples.

    Returns
    -------
    (train_loader, test_loader)
    """
    # The test split is generated first, then the training split.
    test_inputs, test_labels = give_raw_data(n_test,d,F_t,theta)
    train_inputs, train_labels = give_raw_data(n_train,d,F_t,theta)

    train_loader = DataLoader(CustomDataset(train_inputs, train_labels),
                              batch_size=int(n_train), shuffle=True)
    test_loader = DataLoader(CustomDataset(test_inputs, test_labels),
                             batch_size=int(n_test), shuffle=True)
    return train_loader, test_loader

class CustomDataset(Dataset):
    """Map-style dataset that pairs ``inputs[i]`` with ``labels[i]``."""

    def __init__(self, inputs, labels):
        # Both containers are stored as given; nothing is copied or converted.
        self.inputs = inputs
        self.labels = labels

    def __len__(self):
        """Number of samples, taken from the label container."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(input, label)`` pair stored at position ``idx``."""
        return self.inputs[idx], self.labels[idx]

In [None]:
# Per-activation coefficient triples, keyed by activation name.
# The first entry of each triple is discarded below; the second and third are
# unpacked as (kappa1, kappastar).
# NOTE(review): presumably the coefficients of the Gaussian-equivalent
# linearisation of each activation -- confirm against the reference used.
COEFICIENTS = {
    'relu': (1/np.sqrt(2*np.pi), 0.5, np.sqrt((np.pi-2)/(4*np.pi))),
    'erf':  (0, 2/np.sqrt(3*np.pi), 0.200364),
    'tanh': (0, 0.605706, 0.165576),
    'sign': (0, np.sqrt(2/np.pi), np.sqrt(1-2/np.pi)),
}

# Teacher and student both use the tanh coefficients.
_, kappa1_teacher, kappastar_teacher = COEFICIENTS['tanh']
_, kappa1_student, kappastar_student = COEFICIENTS['tanh']

In [None]:
def g3m_input_matrices_simple(d,F_s,F_t):
    """Closed-form teacher/student covariances from the kappa coefficients.

    Each matrix is a rescaled Gram matrix of the feature maps plus a multiple
    of the ``d x d`` identity, using the module-level ``kappa1_*`` and
    ``kappastar_*`` constants. Because ``np.identity(d)`` is added to
    ``F @ F.T``, ``F_s`` and ``F_t`` need ``d`` rows.

    NOTE(review): ``Phi`` also carries an identity term
    (``kappastar_student * kappastar_teacher``); confirm that this is intended
    when ``F_s`` differs from ``F_t``.

    Returns
    -------
    (Psi, Omega, Phi)
        Teacher-teacher, student-student and teacher-student matrices.
    """
    eye = np.identity(d)

    # Gram parts, each rescaled by the matching kappa1 product and 1/d.
    teacher_gram = (kappa1_teacher**2 * F_t @ F_t.T)/d
    student_gram = (kappa1_student**2 * F_s @ F_s.T)/d
    cross_gram = (kappa1_teacher * kappa1_student * F_t @ F_s.T)/d

    # Identity parts weighted by the kappastar products.
    Psi = teacher_gram + kappastar_teacher**2 * eye
    Omega = student_gram + kappastar_student**2 * eye
    Phi = cross_gram + kappastar_student*kappastar_teacher * eye
    return Psi,Omega,Phi


def g3m_input_matrices_MC(d,n,F_s,F_t):
    """Monte-Carlo estimate of the teacher/student feature covariances.

    Draws ``n`` standard-normal inputs of dimension ``d`` from the global
    NumPy RNG, maps them through ``tanh(F @ x / sqrt(d))`` for the teacher
    (``F_t``) and the student (``F_s``), and averages the outer products.

    Returns
    -------
    (Psi, Omega, Phi)
        Empirical teacher-teacher, student-student and teacher-student
        second-moment matrices.
    """
    gaussian_inputs = np.random.randn(d,n)

    def tanh_features(weights):
        # Same feature map for both networks; only the weights differ.
        return np.tanh(weights@gaussian_inputs/np.sqrt(d))

    teacher_feats = tanh_features(F_t)
    student_feats = tanh_features(F_s)

    inv_n = 1/n
    Psi = inv_n*(teacher_feats@teacher_feats.T)
    Omega = inv_n*(student_feats@student_feats.T)
    Phi = inv_n*(teacher_feats@student_feats.T)
    return Psi,Omega,Phi


def g3m_prediction(d, n_list, reg, F_s, F_t, theta, method):
    """Run the state-evolution ridge-regression experiment over sample sizes.

    Parameters
    ----------
    d : int
        Input dimension; the sample ratios passed on are ``n_list / d``.
    n_list : array-like
        Training-set sizes.
    reg : float
        Regularisation strength handed to ``CustomExperiment``.
    F_s, F_t : array
        Student and teacher first-layer weights used to build the covariances.
    theta : array
        Teacher readout weights.
    method : {'MC', 'simple'}
        ``'MC'`` estimates the covariances with ``g3m_input_matrices_MC``
        (10000 samples); ``'simple'`` uses ``g3m_input_matrices_simple``.

    Returns
    -------
    (Eg_g3m, Et_g3m, L_g3m)
        Arrays of the ``'test_error'``, ``'train_loss'`` and ``'loss'``
        entries of the experiment's learning curve.

    Raises
    ------
    ValueError
        If ``method`` is neither ``'MC'`` nor ``'simple'``.
    """
    n_covariance=10000  # number of Monte-Carlo samples for the 'MC' estimate
    if method=='MC':
        Psi,Omega,Phi=g3m_input_matrices_MC(d,n_covariance, F_s, F_t)
    elif method=='simple':
        Psi,Omega,Phi=g3m_input_matrices_simple(d,F_s, F_t)
    else:
        # Fail early instead of hitting an unbound `Psi` further down.
        raise ValueError(f"unknown method {method!r}; expected 'MC' or 'simple'")

    data_model = Custom(teacher_teacher_cov = Psi,
                        student_student_cov = Omega,
                        teacher_student_cov = Phi,
                        teacher_weights = theta)

    my_experiment = CustomExperiment(task = 'ridge_regression',
                                     regularisation = reg,
                                     data_model = data_model,
                                     initialisation='uninformed',
                                     tolerance = 1e-15,
                                     damping = 0.1,
                                     verbose = False,
                                     max_steps = 5000)
    # np.asarray lets a plain list/tuple of sample counts be divided by d.
    my_experiment.learning_curve(alphas = np.asarray(n_list)/d)
    curve=my_experiment.get_curve()
    Eg_g3m=np.array(curve['test_error'])
    Et_g3m=np.array(curve['train_loss'])
    L_g3m=np.array(curve['loss'])
    return Eg_g3m,Et_g3m,L_g3m

In [None]:
def train_loop(dataloader, model, criterion, optimizer, d, regularisation, verbose=True, info='train_loss',train=True):
    """Run one pass over ``dataloader``, optionally updating the model.

    For each batch the objective is
    ``batch_size * criterion(pred/sqrt(d), y/sqrt(d))`` plus
    ``regularisation * ||p||^2`` where ``p`` is the model's second parameter
    (index 1 of ``model.parameters()``); for the bias-free two-layer network
    built in ``train`` this is the last layer.

    Parameters
    ----------
    dataloader : iterable of ``(X, y)`` batches
    model, criterion, optimizer : torch module, loss and optimizer
    d : int
        Dimension used to rescale predictions/labels and the reported loss.
    regularisation : float
        Weight of the squared-norm penalty.
    verbose : bool
        When true, a metric selected by ``info`` is returned.
    info : {'train_loss', 'loss'}
        ``'train_loss'`` reports the bare criterion value, ``'loss'`` the full
        regularised objective divided by ``d``. Both are evaluated on the
        predictions made *before* the optimizer step of that batch.
    train : bool
        When false, no backward pass / optimizer step is performed.

    Returns
    -------
    float or None
        Mean of the selected metric over all batches (identical to the
        single-batch value when the loader yields one batch); ``None`` when
        ``verbose`` is false, ``info`` is unrecognised or the loader is empty.
    """
    params = list(model.parameters())
    reported = []

    # Every batch is processed; returning from inside this loop would end the
    # epoch after the first batch of a multi-batch loader.
    for X, y in dataloader:
        # Compute prediction and loss
        pred = model(X)
        data_term = criterion(pred[:,0]/np.sqrt(d), y/np.sqrt(d))
        loss = y.size(dim=0)*data_term
        if len(params) > 1:            #only regularize the last layer
            loss = loss + regularisation *torch.norm(params[1])**2

        if train:
            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if verbose:
            if info=='train_loss':
                reported.append(data_term.item())
            elif info=='loss':
                reported.append(loss.item()/d)

    if reported:
        return sum(reported)/len(reported)
    return None


def test_loop(dataloader, model, criterion, optimizer, d, verbose=True):
    num_batches = len(dataloader)
    test_loss = 0
    test_loss_bis = 0

    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += criterion(pred[:,0]/np.sqrt(d), y/np.sqrt(d)).item()
            test_loss_bis += np.mean(((pred[:,0]/np.sqrt(d) - y/np.sqrt(d))**2).numpy())
            if np.abs(test_loss - test_loss_bis) > 1e-10:
                print(test_loss - test_loss_bis)

    test_loss /= num_batches
    if verbose:
        print(f"generalisation error: {test_loss:>8f} \n")
    return test_loss

In [None]:
def train(d,Lambda,alpha,init_type,epochs,freeze_1st_layer,F_t,theta,F_random,theta_random):
    """Train a two-layer tanh network for every (regularisation, ratio) pair.

    For each ``reg`` in ``Lambda`` and each sample ratio in ``alpha`` a fresh
    ``Linear(d, d) -> Tanh -> Linear(d, 1)`` network (no biases, float64) is
    trained with Adam (``lr=0.001``) for ``epochs`` full-batch steps on
    ``int(d * alpha)`` teacher-generated samples, then evaluated on 50000
    test samples.

    Parameters
    ----------
    d : int
        Input dimension and hidden width.
    Lambda : iterable of float
        Regularisation strengths.
    alpha : sequence of float
        Sample ratios ``n_train / d``.
    init_type : str
        ``'random'`` starts from ``F_random / sqrt(d)`` and ``theta_random``;
        ``'planted'`` starts from ``F_t / sqrt(d)`` and ``theta``. Any other
        value leaves PyTorch's default initialisation in place.
    epochs : int
        Number of training steps per configuration.
    freeze_1st_layer : bool
        When true, the first layer is excluded from gradient updates.
    F_t, theta : array
        Teacher weights (also used to generate the data).
    F_random, theta_random : array
        Weights used by the ``'random'`` initialisation.

    Returns
    -------
    (Eg, Et, Loss, W, Fs, Loss_decay)
        Nested lists indexed ``[lambda][alpha]``: generalisation error, train
        loss, regularised loss, final second-layer weights, final first-layer
        weights (rescaled by ``sqrt(d)``), and the per-epoch loss history.
        All final quantities are measured on the same trained weights.
    """
    Loss_decay=[]
    Eg=[]
    Et=[]
    Loss=[]
    W=[]
    Fs=[]
    n_array=(d*np.array(alpha)).astype(int)

    n_test=50000 #number of samples used to evaluate the generalisation error
    for reg in Lambda:
        fs=[]
        w=[]
        loss_decay=[]
        eg,et,l=[],[],[]
        for j,n_train in enumerate(n_array):

            #Building the network
            model=nn.Sequential(nn.Linear(d, d, bias=False),
                                nn.Tanh(),
                                nn.Linear(d, 1, bias=False))

            #Initialisation
            if init_type=='random':
                model[0].weight.data=torch.from_numpy(F_random/np.sqrt(d))
                model[2].weight.data=torch.from_numpy(np.array([theta_random]))
            if init_type=='planted':
                model[0].weight.data=torch.from_numpy(F_t/np.sqrt(d))
                model[2].weight.data=torch.from_numpy(np.array([theta]))

            model = model.double()

            # Define the loss
            criterion = nn.MSELoss()

            # Optimizers require the parameters to optimize and a learning rate
            optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

            #Freeze first layer
            if freeze_1st_layer:
                model[0].weight.requires_grad = False

            loss=[]
            train_dataloader,test_dataloader = give_dataset(n_test,n_train,d,F_t,theta)
            for t in range(epochs):
                loss.append(train_loop(train_dataloader,model,criterion,optimizer,d,reg,verbose=True,info='loss',train=True))

            if epochs > 0:
                # Final measurements are evaluation-only (train=False): taking
                # further optimizer steps here would make the train loss, the
                # regularised loss and the stored weights refer to different
                # models than the generalisation error.
                eg.append(test_loop(test_dataloader,model,criterion,optimizer,d,verbose=False))
                et.append(train_loop(train_dataloader,model,criterion,optimizer,d,reg,verbose=True,train=False))
                l.append(train_loop(train_dataloader,model,criterion,optimizer,d,reg,verbose=True,info='loss',train=False))
                # .numpy() shares memory with the tensor; store independent arrays.
                fs.append(model[0].weight.data.numpy()*np.sqrt(d))
                w.append(model[2].weight.data.numpy().copy())
            loss_decay.append(loss)
            print(f"lambda={reg}, alpha={alpha[j]} is done")
        Eg.append(eg)
        Et.append(et)
        Loss.append(l)
        W.append(w)
        Fs.append(fs)
        Loss_decay.append(loss_decay)
    return Eg,Et,Loss,W,Fs,Loss_decay

In [None]:
# Colour-channel table: for every channel a tuple of (x, y, y) triples.
# Each stop is written once as (x, y) and the y value is duplicated, since the
# second and third entries of every triple are equal.
# NOTE(review): presumably segment data for a matplotlib colormap -- confirm
# against the cell that consumes `cdict`.
cdict = {
    channel: tuple((position, level, level) for position, level in stops)
    for channel, stops in (
        ('red',   ((0.0, 1.0), (0.5, 0.6), (1.0, 0.5))),
        ('green', ((0.0, 0.3), (0.5, 0.4), (1.0, 0.5))),
        ('blue',  ((0.0, 0.3), (0.5, 0.6), (1.0, 1.0))),
    )
}