In [None]:
import jax.numpy as jnp
from PIL.ImageChops import offset
from jax import grad, jit, vmap
from jax import random

import matplotlib.pyplot as plt

import numpy as np
from qtconsole.mainwindow import background
from scipy.stats import alpha

from Utils.models import *
from Utils.simulation import *
from Utils.plot_utils import *

import time
from copy import copy

# Re-import edited project modules (e.g. Utils.*) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
pow(2, 7)  # scratch check: equals the hidden_size (128) configured below

In [None]:
# Experiment configuration grouped by stage; most later cells read their settings from here.
params_dict = dict(
    dataset_parameters=dict(n_samples=200),  # size of the synthetic binary dataset
    network_parameters=dict(
        input_size=64,
        hidden_size=128,
        output_size=1,
        bias=1,
    ),
    training_parameters=dict(num_epochs=100, learning_rate=0.01),
    # GOU weight-perturbation settings: dt is the integration step, tau the
    # simulated time applied per training sample (tau / dt sub-steps).
    simulation_parameters=dict(mu=1, sigma=0.1, theta=0.02, dt=0.001, tau=0.005),
    seed=42,
)

In [None]:
# Root JAX PRNG key; every later key is derived from it via random.split.
rng = random.key(seed=params_dict["seed"])

In [None]:
# Log-normal weight-distribution parameters derived from the simulation settings,
# then an ELM initialised with them.
sim_params = params_dict["simulation_parameters"]
mu_LN = mu_LN_from_params(**sim_params)
sigma_LN = sigma_LN_from_params(**sim_params)

rng, net_key = random.split(rng)
params = init_elm(net_key, mu_LN, sigma_LN, **params_dict["network_parameters"])

In [None]:
# Distribution of all entries of the freshly initialised input weights W_i.
plt.hist(params['W_i'].ravel(), bins=100)
plt.show()


In [None]:
# simulate perturbations of the weights of W_i
# Simulate perturbation-only dynamics (no training) of the flattened W_i.
simulation_parameters = params_dict["simulation_parameters"]
rng, sim_key = random.split(rng)
N_PERTURBATION_STEPS = 1000
sp = simulation_parameters
# presumably returns an array indexed [step, weight] (see the plots below) — TODO confirm
weight_list = simulate_perturbation_only(
    sim_key,
    params['W_i'].flatten(),
    N_PERTURBATION_STEPS,
    sp['mu'],
    sp['theta'],
    sp['sigma'],
    sp['dt'],
)


In [None]:
# Trajectories of the first 100 weights over the simulated steps.
first_traces = weight_list[:, :100]
plt.plot(first_traces)

In [None]:
# Final weights of the perturbation-only simulation (red), overlaid with the
# initial W_i distribution (default colour, drawn second).
final_sim_weights = weight_list[-1]
plt.hist(final_sim_weights, bins=100, color='r')
initial_weights = params['W_i'].flatten()
plt.hist(initial_weights, bins=100)
plt.show()



In [None]:
# Evolve the full W_i matrix with the GOU process for one tau of simulated time.
rng, sim_key = random.split(rng)
W_f = time_evolution_GOU(sim_key, params['W_i'], **dict(simulation_parameters))



In [None]:
# Compare the perturbation-only endpoint (red) with the time_evolution_GOU result.
sim_endpoint = weight_list[-1]
gou_result = W_f.flatten()
plt.hist(sim_endpoint, bins=100, color='r')
plt.hist(gou_result, bins=100)
plt.show()

In [None]:
# Synthetic binary classification data sized from the config.
rng, data_key = random.split(rng)

dataset_cfg = params_dict["dataset_parameters"]
X_train, y_train = create_binary_dataset(
    data_key,
    n_samples=dataset_cfg["n_samples"],
    input_dim=params_dict["network_parameters"]["input_size"],
)

In [None]:

# Online (one sample at a time) ELM training while W_i is evolved by the GOU
# process before every gradient step.
# NOTE: `params` is updated in place, so re-running this cell continues from the
# current weights rather than from the initialisation cell above.
loss_list = []
acc_list = []

training_parameters = params_dict["training_parameters"]
num_epochs = training_parameters["num_epochs"]
learning_rate = training_parameters["learning_rate"]
simulation_parameters = params_dict["simulation_parameters"]

# Snapshot of the flattened W_i after every training sample.
weight_list = []
# Dedicated, reproducible key stream for this run. Re-creating random.key(seed)
# here would replay keys already consumed above (the first GOU key would equal
# `net_key`, the second `sim_key`, ...), correlating noise with the init.
rng = random.fold_in(random.key(params_dict["seed"]), 1)
loss_grad = grad(loss_elm)  # build the gradient function once, not per sample

for epoch in range(num_epochs):
    start_time = time.time()
    for x, y in zip(X_train, y_train):
        rng, gou_key = random.split(rng)
        # perturb W_i for one tau of simulated time
        params['W_i'] = time_evolution_GOU(gou_key, params['W_i'], **simulation_parameters)

        grads = loss_grad(params, x, y)
        for name in ('W_i', 'W_o', 'b_i', 'b_o'):
            params[name] -= learning_rate * grads[name]

        weight_list.append(params['W_i'].flatten())

    # evaluate once per epoch and reuse the values for logging
    train_acc = accuracy_elm(params, X_train, y_train)
    train_loss = loss_elm(params, X_train, y_train)
    acc_list.append(train_acc)
    loss_list.append(train_loss)

    if epoch % 10 == 0:
        epoch_time = time.time() - start_time
        print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
        print("Training set loss {}".format(train_loss))
        print("Training set accuracy {}".format(train_acc))




In [None]:
# Stack the per-sample W_i snapshots into a 2-D array; show the last snapshot's distribution.
weight_list = np.asarray(weight_list)
last_snapshot = weight_list[-1]
plt.hist(last_snapshot, bins=500)
plt.show()


In [None]:
# Long figure: trajectories of the first 100 W_i entries (one point per training
# sample), with a faint red line at every epoch boundary.
n_samples = params_dict["dataset_parameters"]["n_samples"]
plt.figure(figsize=(20, 5))

# weight_list has one row per sample, so an epoch spans n_samples rows
# (derived from the config instead of a hard-coded 200).
for i in range(0, weight_list.shape[0], n_samples):
    plt.axvline(x=i, color='r', linestyle='--', alpha=0.1)
plt.plot(weight_list[:, :100], alpha=0.5, c='b')

plt.show()
    



In [None]:
# Save the per-sample W_i trajectory from the training run.
import os
os.makedirs("old_results", exist_ok=True)  # np.save does not create missing directories
np.save("old_results/weight_list_training.npy", weight_list)

In [None]:
# Three-panel summary of the saved training run: W_i distribution at the start,
# trajectories of 20 weights over simulated time, W_i distribution at the end.
weight_list = np.load("old_results/weight_list_training.npy")

n_samples = params_dict["dataset_parameters"]["n_samples"]
tau = simulation_parameters['tau']

fig, axs = plt.subplots(1, 3, figsize=(15, 3), gridspec_kw={'width_ratios': [1, 4, 1]})

# NOTE: weight_list[0] is W_i after the first sample's perturbation and
# gradient step, not the exact initialisation.
axs[0].hist(weight_list[0], bins=500)
axs[0].set_title("Initial weights")
axs[0].set_xlim(0, 5)

# One row per training sample, each tau of simulated time apart; faint line at
# every epoch boundary (n_samples samples).
for i in range(num_epochs):
    axs[1].axvline(x=i * tau * n_samples, alpha=0.1)
times = np.arange(weight_list.shape[0]) * tau  # matches the loaded array's length
axs[1].plot(times, weight_list[:, :20], alpha=0.5, c='b')
axs[1].set_title("Weight evolution")
axs[1].set_xlim(0, num_epochs * n_samples * tau)

axs[2].hist(weight_list[-1], bins=500)
axs[2].set_title("Final weights")
axs[2].set_xlim(0, 5)

# Lay out before saving so the PNG matches the displayed figure.
plt.tight_layout()
plt.savefig("weight_evolution.png", dpi=300, bbox_inches='tight')
plt.show()

In [None]:
np.shape(times)  # sanity check of the time axis length

In [None]:
# Loss and accuracy per epoch, side by side.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
for panel, series, title in zip(ax, (loss_list, acc_list), ("Loss", "Accuracy")):
    panel.plot(series)
    panel.set_title(title)
plt.show()


In [None]:
#plot losses and accuracy next to each other
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
for panel, series, title in zip(ax, (loss_list, acc_list), ("Loss", "Accuracy")):
    panel.plot(series)
    panel.set_title(title)
plt.show()



In [None]:
# Keep evolving W_i with the GOU process only (no gradient updates) for another
# num_epochs epochs, appending to acc_list / loss_list so the curves continue
# past the end of training.
for epoch in range(num_epochs):
    start_time = time.time()
    for x, y in zip(X_train, y_train):
        rng, gou_key = random.split(rng)
        params['W_i'] = time_evolution_GOU(gou_key, params['W_i'], **simulation_parameters)

    # evaluate once per epoch and reuse the values for logging
    train_acc = accuracy_elm(params, X_train, y_train)
    train_loss = loss_elm(params, X_train, y_train)
    acc_list.append(train_acc)
    loss_list.append(train_loss)

    if epoch % 10 == 0:
        epoch_time = time.time() - start_time
        print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
        print("Training set loss {}".format(train_loss))
        print("Training set accuracy {}".format(train_acc))

In [None]:
# Loss and accuracy per epoch, with a dashed red marker at the last training epoch.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
for panel, series, title in zip(ax, (loss_list, acc_list), ("Loss", "Accuracy")):
    panel.plot(series)
    panel.set_title(title)
    panel.axvline(x=num_epochs, color='r', linestyle='--')
plt.show()


In [None]:
# Accuracy over simulated time. acc_list holds one value per epoch (training
# epochs, then any perturbation-only epochs appended afterwards); one epoch is
# n_samples samples of tau each.
n_samples = params_dict["dataset_parameters"]["n_samples"]
epoch_duration = n_samples * simulation_parameters['tau']
end_of_training = num_epochs * epoch_duration

plt.figure(figsize=(3, 3))
# derive the axis from len(acc_list) instead of assuming 2 * num_epochs entries
epochs_time = np.arange(len(acc_list)) * epoch_duration
plt.plot(epochs_time, acc_list, 'r', label='Accuracy')
# vertical line at the end of training, labelled just to its left
plt.axvline(x=end_of_training, color='k', alpha=0.5)
plt.text(0.935 * end_of_training, 0.52, 'End of training', rotation=90)
# reference line at 0.5 accuracy
plt.axhline(y=0.5, color='g', linestyle='--')
plt.legend(fontsize=9)
plt.xlim(50, 130)
plt.savefig("accuracy.png", dpi=300, bbox_inches='tight')
plt.show()

# Sweep over the perturbation time tau (ELM)

In [None]:
# For each tau, train a freshly initialised ELM on the same dataset (W_i
# perturbed by the GOU process before every gradient step) and record the final
# training loss / accuracy.
rng = random.key(params_dict["seed"])

mu_LN = mu_LN_from_params(**params_dict["simulation_parameters"])
sigma_LN = sigma_LN_from_params(**params_dict["simulation_parameters"])

rng, data_key = random.split(rng)

X_train, y_train = create_binary_dataset(data_key, n_samples=params_dict["dataset_parameters"]["n_samples"], input_dim=params_dict["network_parameters"]["input_size"])

# NOTE(review): float-step arange; endpoint inclusion is sensitive to rounding.
tau_list = jnp.arange(0.00, 0.05, 0.005)
loss_tau = []
acc_tau = []

training_parameters = params_dict["training_parameters"]
num_epochs = training_parameters["num_epochs"]
learning_rate = training_parameters["learning_rate"]
# Work on a copy: assigning "tau" below must not silently rewrite
# params_dict["simulation_parameters"] for every later cell.
simulation_parameters = copy(params_dict["simulation_parameters"])
loss_grad = grad(loss_elm)  # build the gradient function once

print('tau_list', tau_list)

for tau in tau_list:
    print("Tau: ", tau)
    simulation_parameters["tau"] = tau
    rng, net_key = random.split(rng)
    params = init_elm(net_key, mu_LN, sigma_LN, **params_dict["network_parameters"])

    for epoch in range(num_epochs):
        for x, y in zip(X_train, y_train):
            rng, gou_key = random.split(rng)
            # perturb W_i; tau == 0 means pure gradient training
            if tau != 0:
                params['W_i'] = time_evolution_GOU(gou_key, params['W_i'], **simulation_parameters)

            grads = loss_grad(params, x, y)
            for name in ('W_i', 'W_o', 'b_i', 'b_o'):
                params[name] -= learning_rate * grads[name]

        if epoch % 10 == 0:
            print("Epoch {}. Acc = {}".format(epoch, accuracy_elm(params, X_train, y_train)))

    # evaluate once and reuse for printing and recording
    final_loss = loss_elm(params, X_train, y_train)
    final_acc = accuracy_elm(params, X_train, y_train)
    print("Loss: ", final_loss)
    print("Accuracy: ", final_acc)
    acc_tau.append(final_acc)
    loss_tau.append(final_loss)
    

In [None]:
# Final loss and accuracy as a function of tau, side by side.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
for panel, values, title in zip(ax, (loss_tau, acc_tau), ("Loss", "Accuracy")):
    panel.plot(tau_list, values, '.')
    panel.set_title(title)
plt.show()


In [None]:
# Save the tau-sweep results currently held in loss_tau / acc_tau.
import os
os.makedirs("old_results", exist_ok=True)  # np.save does not create missing directories
np.save("old_results/loss_tau.npy", loss_tau)
np.save("old_results/acc_tau.npy", acc_tau)


In [None]:
# Mean ± std of accuracy per tau, from a results file of repeated runs.
tau_list = jnp.arange(0.00, 0.05, 0.005)
acc_tau = np.load("old_results/acc_tau_tot.npy")
# assumes acc_tau has shape (len(tau_list), n_runs) — TODO confirm
acc_mean = np.mean(acc_tau, axis=1)
acc_std = np.std(acc_tau, axis=1)

plt.figure(figsize=(2, 2))
plt.plot(tau_list, acc_mean, 'r', label='Accuracy')
plt.fill_between(tau_list, acc_mean - acc_std, acc_mean + acc_std, alpha=0.3, color='r')

plt.xlabel('Tau')
plt.ylim(0.4, 1)
plt.axhline(y=0.5, color='g', linestyle='--')  # reference line at 0.5 accuracy
plt.legend()
plt.savefig("tau_accuracy.png", dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Mean ± std of accuracy per tau on explicit axes. Creates its own figure rather
# than drawing into an `ax` left over from an earlier, already-shown figure.
fig, ax_acc = plt.subplots(figsize=(4, 4))
acc_mean = np.mean(acc_tau, axis=1)
acc_std = np.std(acc_tau, axis=1)
ax_acc.plot(tau_list, acc_mean, label='Accuracy')
ax_acc.fill_between(tau_list, acc_mean - acc_std, acc_mean + acc_std, alpha=0.3)
ax_acc.set_xlabel('Tau')
ax_acc.set_ylabel('Accuracy')
ax_acc.legend()
plt.show()

# Sweep over the perturbation time tau (MLP)

In [None]:
# Settings for the MLP tau sweep (the next cell repeats this setup and runs the sweep).
training_parameters = params_dict["training_parameters"]
num_epochs, learning_rate = training_parameters["num_epochs"], training_parameters["learning_rate"]
simulation_parameters = params_dict["simulation_parameters"]

tau_list = jnp.arange(0.00, 0.1, 0.005)
loss_tau, acc_tau = [], []

print('tau_list', tau_list)



In [None]:
# For each tau, train a freshly initialised MLP whose hidden weights W_h are
# GOU-perturbed before every gradient step; record final loss / accuracy.
# NOTE(review): X_train / y_train are not created here; they come from the most
# recent dataset cell (the ELM tau sweep when run in order).
rng = random.key(params_dict["seed"])

mu_LN = mu_LN_from_params(**params_dict["simulation_parameters"])
sigma_LN = sigma_LN_from_params(**params_dict["simulation_parameters"])

tau_list = jnp.arange(0.00, 0.1, 0.005)
loss_tau = []
acc_tau = []

training_parameters = params_dict["training_parameters"]
num_epochs = training_parameters["num_epochs"]
learning_rate = training_parameters["learning_rate"]
# Work on a copy: assigning "tau" below must not leak into params_dict.
simulation_parameters = copy(params_dict["simulation_parameters"])
loss_grad = grad(loss_mlp)  # build the gradient function once

print('tau_list', tau_list)

for tau in tau_list:
    print("Tau: ", tau)
    simulation_parameters["tau"] = tau
    rng, net_key = random.split(rng)
    params = init_mlp(net_key, mu_LN, sigma_LN, **params_dict["network_parameters"])

    for epoch in range(num_epochs):
        start_time = time.time()
        for x, y in zip(X_train, y_train):
            rng, gou_key = random.split(rng)
            # perturb the hidden-layer weights W_h; tau == 0 means no perturbation
            if tau != 0:
                params['W_h'] = time_evolution_GOU(gou_key, params['W_h'], **simulation_parameters)

            grads = loss_grad(params, x, y)
            # only W_i, W_h and W_o are updated.
            # NOTE(review): any biases returned by init_mlp stay untrained.
            for name in ('W_i', 'W_h', 'W_o'):
                params[name] -= learning_rate * grads[name]

        if epoch % 10 == 0:
            epoch_time = time.time() - start_time
            print("Epoch {} in {:0.2f} sec. Acc = {}".format(epoch, epoch_time, accuracy_mlp(params, X_train, y_train)))

    # evaluate once and reuse for printing and recording
    final_loss = loss_mlp(params, X_train, y_train)
    final_acc = accuracy_mlp(params, X_train, y_train)
    print("Loss: ", final_loss)
    print("Accuracy: ", final_acc)
    acc_tau.append(final_acc)
    loss_tau.append(final_loss)



In [None]:
# Train an ELM for a few epochs while recording W_i at every GOU sub-step (dt)
# and after every gradient update, for the zoomed trajectory plot below.
N_TRACE_EPOCHS = 23      # total epochs trained in this cell
TRACE_START_EPOCH = 20   # record only epochs 20, 21, 22

loss_list = []
acc_list = []

# Self-contained, reproducible keys: do not reuse the stale `net_key`/`rng`
# left over from earlier cells, and do not reseed after drawing from `rng`
# (that replays keys already used).
rng = random.fold_in(random.key(params_dict["seed"]), 2)
rng, net_key, data_key = random.split(rng, 3)

mu_LN = mu_LN_from_params(**params_dict["simulation_parameters"])
sigma_LN = sigma_LN_from_params(**params_dict["simulation_parameters"])
params = init_elm(net_key, mu_LN, sigma_LN, **params_dict["network_parameters"])

X_train, y_train = create_binary_dataset(data_key,
                                         n_samples=params_dict["dataset_parameters"]["n_samples"],
                                         input_dim=params_dict["network_parameters"]["input_size"])

training_parameters = params_dict["training_parameters"]
num_epochs = training_parameters["num_epochs"]
learning_rate = training_parameters["learning_rate"]
# NOTE(review): if an older version of a tau-sweep cell mutated
# params_dict["simulation_parameters"]["tau"], that value is used here.
simulation_parameters = copy(params_dict["simulation_parameters"])
# round(), not int(): tau / dt is a float ratio and int() would truncate e.g. 4.999... to 4
n_steps = int(round(simulation_parameters['tau'] / simulation_parameters['dt']))
loss_grad = grad(loss_elm)

weight_list = []
for epoch in range(N_TRACE_EPOCHS):
    start_time = time.time()
    record = epoch >= TRACE_START_EPOCH
    for x, y in zip(X_train, y_train):
        rng, gou_key = random.split(rng)
        # perturb W_i with n_steps explicit GOU sub-steps of length dt
        for _ in range(n_steps):
            gou_key, sim_key = random.split(gou_key)
            params['W_i'] += perturb_GOU(sim_key,
                                         params['W_i'],
                                         simulation_parameters['mu'],
                                         simulation_parameters['theta'],
                                         simulation_parameters['sigma'],
                                         simulation_parameters['dt'])
            if record:
                weight_list.append(params['W_i'].flatten())

        grads = loss_grad(params, x, y)
        for name in ('W_i', 'W_o', 'b_i', 'b_o'):
            params[name] -= learning_rate * grads[name]

        # one extra entry per sample: W_i right after the gradient step
        if record:
            weight_list.append(params['W_i'].flatten())

    if record:
        acc_list.append(accuracy_elm(params, X_train, y_train))
        loss_list.append(loss_elm(params, X_train, y_train))

    if epoch % 10 == 0:
        epoch_time = time.time() - start_time
        train_loss = loss_elm(params, X_train, y_train)
        train_acc = accuracy_elm(params, X_train, y_train)
        print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
        print("Training set loss {}".format(train_loss))
        print("Training set accuracy {}".format(train_acc))



In [None]:
# Stack the recorded flattened W_i snapshots into a 2-D array (one row per record).
weight_list = np.asarray(weight_list)
weight_list.shape

In [None]:
# Zoomed view of the W_i trace recorded by the fine-grained trace cell. Each
# recorded sample contributes n_steps GOU sub-steps (dt apart) followed by one
# gradient-update entry, and spans tau of simulated time; one epoch therefore
# spans n_samples * tau (the same time convention as the full-run plot).
tau = simulation_parameters['tau']
dt = simulation_parameters['dt']
n_samples = params_dict["dataset_parameters"]["n_samples"]
n_steps = int(round(tau / dt))
entries_per_sample = n_steps + 1
epoch_duration = n_samples * tau

start_epoch = 20  # first recorded epoch in the trace cell
off_set = start_epoch * epoch_duration

entry_idx = np.arange(weight_list.shape[0])
sample_idx, step_idx = np.divmod(entry_idx, entries_per_sample)
# sub-step j ends (j + 1) * dt into its sample; the gradient entry
# (step_idx == n_steps) is instantaneous and shares the last sub-step's time.
times = off_set + sample_idx * tau + np.minimum(step_idx + 1, n_steps) * dt

fig = plt.figure(figsize=(5, 2))
n_recorded_samples = weight_list.shape[0] // entries_per_sample
# faint dotted line at every sample (gradient-step) boundary
for k in range(n_recorded_samples + 1):
    plt.axvline(x=off_set + k * tau, color='g', linestyle='dotted', alpha=0.1)
# solid line at every epoch boundary
for k in range(n_recorded_samples // n_samples + 1):
    plt.axvline(x=off_set + k * epoch_duration)

plt.plot(times, weight_list[:, :20], alpha=0.5, c='b')
plt.tight_layout()
# full recorded range would be: plt.xlim(times[0], times[-1])
plt.xlim(off_set, off_set + 2 * epoch_duration)
plt.savefig("weight_evolution_zoom.png", dpi=300, bbox_inches='tight')
plt.show()


In [None]:
np.shape(weight_list)  # (n_records, n_weights)


In [None]:
float(7600) / 2  # scratch arithmetic

In [None]:
float(3800) / 5  # scratch arithmetic

In [None]:
list(acc_list)  # inspect recorded per-epoch accuracies