In [1]:
# Automatically reload edited modules (e.g. the local Utils package) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

import matplotlib.pyplot as plt

import numpy as np

# Project helpers (models, simulation, plotting).
# NOTE(review): star imports hide where names such as init_mlp, perturb_GOU or
# plot_weights come from; consider importing them explicitly.
from Utils.models import *
from Utils.simulation import *
from Utils.plot_utils import *

import time
from copy import copy

# autoreload is configured once in the first cell, so it is not loaded again here.

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


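First, we define the experiment's basic settings: dataset size, network architecture, training schedule, simulation parameters, and the random seed.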

In [3]:
import json

# All tunable settings of the experiment in one place.
params_dict = {
    "dataset_parameters": {
        "n_samples": 200
    },
    "network_parameters": {
        # Forwarded as keyword arguments to init_mlp; "bias" is also passed to plot_eigenvalues.
        "input_size": 2**6,
        "hidden_size": 2**7,
        "output_size": 1,
        "bias": 1,
    },
    "training_parameters": {
        "num_epochs": 50,
        "learning_rate": 0.1
    },
    "simulation_parameters": {
        # Forwarded as keyword arguments to the LN-parameter helpers and the GOU perturbation functions.
        "mu": 1,
        "sigma": 0.1,
        "theta": 0.5,
        "dt": 0.001
    },
    "seed": 42
}

# Set to True to write the settings above to 'simulation_parameters.json'.
SAVE_PARAMS = False
if SAVE_PARAMS:
    with open('simulation_parameters.json', 'w') as json_file:
        json.dump(params_dict, json_file, indent=4)

# Root JAX PRNG key; later cells split fresh subkeys off it.
rng = random.key(params_dict["seed"])

In [4]:
# Derive the "LN" parameters (presumably log-normal; confirm in Utils) from the
# simulation settings, and display sigma_LN.
sim_kwargs = params_dict["simulation_parameters"]
mu_LN = mu_LN_from_params(**sim_kwargs)
sigma_LN = sigma_LN_from_params(**sim_kwargs)

sigma_LN

Array(0.10025112, dtype=float32, weak_type=True)

In [5]:
# Split off a dedicated key for network initialisation; `rng` is kept for later splits.
rng, net_key = random.split(rng)
# NOTE(review): mu_LN / sigma_LN are presumably used by init_mlp to draw the initial weights — confirm in Utils.models.
params = init_mlp(net_key, mu_LN, sigma_LN, **params_dict["network_parameters"])
# Keep the initial hidden weights for comparison after training and perturbation.
W_h_init = copy(params['W_h'])

We create a simple synthetic dataset of binary input arrays, each paired with a binary target output.

In [6]:
# Fresh subkey for data generation; one row per sample, input width matches the network.
rng, data_key = random.split(rng)
X_train, y_train = create_binary_dataset(
    data_key,
    n_samples=params_dict["dataset_parameters"]["n_samples"],
    input_dim=params_dict["network_parameters"]["input_size"],
)


In [8]:
# Sanity-check shapes: (n_samples, input_size) in, (n_samples, output_size) out.
print(X_train.shape)
train_out = forward_mlp(params, X_train)
print(train_out.shape)

assert X_train.shape == (
    params_dict["dataset_parameters"]["n_samples"],
    params_dict["network_parameters"]["input_size"],
), "unexpected input shape"
assert train_out.shape == (
    X_train.shape[0],
    params_dict["network_parameters"]["output_size"],
), "unexpected network output shape"

(200, 64)
(200, 1)


In [None]:


# Training experiment in two consecutive loops, each `num_epochs` long:
#   1. learning: per-sample gradient descent on W_i, W_h, W_o, with a GOU
#      perturbation added to W_h at every step;
#   2. drift: learning switched off, only the GOU perturbation keeps acting on W_h.
# `loss_list` / `acc_list` receive one full-dataset entry per epoch of BOTH loops;
# `train_weights_lists` snapshots every layer after each sample of loop 1 only.

loss_list = []
acc_list = []

train_weights_lists = {'W_i': [], 'W_h': [], 'W_o': []}

training_parameters = params_dict["training_parameters"]
num_epochs = training_parameters["num_epochs"]
learning_rate = training_parameters["learning_rate"]
simulation_parameters = params_dict["simulation_parameters"]

LOG_EVERY = 10  # print progress on epochs 0, LOG_EVERY, 2*LOG_EVERY, ...

# Build the gradient function once rather than once per sample.
loss_grad = grad(loss_mlp)


def _evaluate_and_log(epoch, start_time, params, X, y, acc_hist, loss_hist, log_every=LOG_EVERY):
    """Append full-dataset accuracy and loss for `params`, printing them periodically.

    Args:
        epoch: index of the epoch that just finished.
        start_time: `time.time()` taken at the start of that epoch.
        params: current network parameters (dict of weight arrays).
        X, y: dataset the metrics are computed on.
        acc_hist, loss_hist: lists the new accuracy / loss values are appended to.
        log_every: print timing and metrics when `epoch % log_every == 0`.
    """
    train_acc = accuracy_mlp(params, X, y)
    train_loss = loss_mlp(params, X, y)
    acc_hist.append(train_acc)
    loss_hist.append(train_loss)

    if epoch % log_every == 0:
        epoch_time = time.time() - start_time
        print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
        print("Training set loss {}".format(train_loss))
        print("Training set accuracy {}".format(train_acc))


# Loop 1: learning + GOU perturbation of the hidden weights.
for epoch in range(num_epochs):
    start_time = time.time()
    for x, y in zip(X_train, y_train):
        rng, gou_key = random.split(rng)

        grads = loss_grad(params, x, y)
        # The perturbation is computed from the hidden weights *before* this step's update.
        gou_noise = perturb_GOU(gou_key, params['W_h'], **simulation_parameters)

        # Rebind (rather than update in place) so every snapshot below is a distinct array.
        params['W_i'] = params['W_i'] - learning_rate * grads['W_i']
        params['W_h'] = params['W_h'] - learning_rate * grads['W_h'] + gou_noise
        params['W_o'] = params['W_o'] - learning_rate * grads['W_o']

        for layer, snapshots in train_weights_lists.items():
            snapshots.append(params[layer])

    _evaluate_and_log(epoch, start_time, params, X_train, y_train, acc_list, loss_list)


# Loop 2: no learning; only the GOU perturbation acts on the hidden weights,
# one step per training sample per epoch.
for epoch in range(num_epochs):
    start_time = time.time()
    for _ in range(len(X_train)):
        rng, gou_key = random.split(rng)
        params['W_h'] = params['W_h'] + perturb_GOU(gou_key, params['W_h'], **simulation_parameters)

    _evaluate_and_log(epoch, start_time, params, X_train, y_train, acc_list, loss_list)

In [None]:
# Stack each layer's per-step snapshots into a 2-D array of shape
# (n_recorded_steps, n_weights_in_layer) for the weight-dynamics plots.
# Using the actual snapshot count keeps this correct even if the number of
# training steps differs from num_epochs * n_samples, and re-running is harmless.
for layer, snapshots in train_weights_lists.items():
    train_weights_lists[layer] = np.array(snapshots).reshape(len(snapshots), -1)


In [None]:

# Loss (left) and accuracy (right), one point per epoch across both training loops.
# The dashed red line at x = num_epochs marks where learning stops and only the
# GOU perturbation remains.
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))

ax_loss.plot(loss_list)
ax_loss.set_ylim([0, 3])
ax_loss.set(title='Loss', xlabel='Epoch', ylabel='Training loss')

ax_acc.plot(acc_list)
ax_acc.set_ylim([0, 1.1])
ax_acc.set(title='Accuracy', xlabel='Epoch', ylabel='Training accuracy')

for ax in (ax_loss, ax_acc):
    ax.axvline(x=num_epochs, color='r', linestyle='--', label='End of training')
ax_acc.legend()

fig.savefig("loss_acc_test.png")
plt.show()

In [None]:

# Hidden-layer weight trajectories recorded during the learning loop.
# A log-scale variant is available by passing log=True to plot_weight_dynamics.
name = "W_h"

plot_weight_dynamics(
    train_weights_lists[name],
    "Weight Dynamics",
    weights_to_show=600,
)


In [None]:
# Hidden weights after training vs. at initialisation, followed by their difference.
delta = params['W_h'] - W_h_init

plot_weights(params['W_h'], "Final Weights")
plot_weights(W_h_init, "Initial Weights")
plot_weights(delta, "Weight Variation")

# Eigenvalues
---------------------------------------------

In [None]:
# Eigenvalues of the final vs. initial hidden weights, plotted with log=False and then log=True.
net_bias = params_dict["network_parameters"]["bias"]
plot_eigenvalues(params['W_h'], W_h_init, bias=net_bias, log=False)
plot_eigenvalues(params['W_h'], W_h_init, bias=net_bias, log=True)

In [None]:
# Baseline: evolve the *initial* hidden weights under the GOU perturbation alone
# (no learning), for as many steps as the first (learning) loop of the training
# cell: num_epochs epochs x one step per training sample.
# NOTE(review): the training cell's second loop applies another equally long
# perturbation-only stretch to params['W_h'], so final-weight comparisons with
# the trained network are over twice as many GOU steps — confirm this is intended.
rng, pert_key = random.split(rng)
W_h_perturb = copy(W_h_init)

n_gou_steps = num_epochs * len(X_train)
perturbed_weight_list = simulate_perturbation_only(
    pert_key,
    W_h_perturb,
    n_steps=n_gou_steps,
    **simulation_parameters,
)

# Last entry of the returned trajectory.
final_weights = perturbed_weight_list[-1]

In [None]:
# Hidden weights at the end of the perturbation-only simulation, next to the initial ones.
plot_weights(final_weights, "Final Perturbed Weights")

plot_weights(W_h_init, "Initial Weights")

In [None]:
# Net change of the hidden weights produced by the GOU perturbation alone.
plot_weights(final_weights - W_h_init, "Weight Variation")

In [None]:
# Trajectory of every hidden weight under the GOU perturbation only,
# flattened to one row per simulated step.
gou_only_trajectories = perturbed_weight_list.reshape(len(perturbed_weight_list), -1)
plot_weight_dynamics(gou_only_trajectories, "Weight Dynamics (only GOU)")

In [None]:
# Mean of the initial hidden weights (compare with final_weights.mean()).
W_h_init.mean()

In [None]:
# Mean of the hidden weights after the perturbation-only simulation.
final_weights.mean()

In [None]:
# Variance of the initial hidden weights (compare with final_weights.var()).
W_h_init.var()

In [None]:
# Variance of the hidden weights after the perturbation-only simulation.
final_weights.var()


In [None]:
# Shape of the training inputs (rows = samples).
print(X_train.shape)