In [1]:
import jax
import jax.numpy as jnp
from PIL.ImageChops import offset
from jax import grad, jit, vmap
from jax import random

import matplotlib.pyplot as plt

import numpy as np
from qtconsole.mainwindow import background
from scipy.stats import alpha
from zmq import WSS_KEY_PEM

from Utils.models import *
from Utils.simulation import *
from Utils.plot_utils import *

import time
from copy import copy


%load_ext autoreload
%autoreload 2

In [2]:

# Select the CPU backend for every JAX computation in this notebook
# (keep this ahead of any JAX computation).
CPU_PLATFORM = "cpu"
jax.config.update("jax_platform_name", CPU_PLATFORM)

# Sanity check: only CPU devices should be listed.
available_devices = jax.devices()
print(available_devices)


[CpuDevice(id=0)]


In [None]:
# Central experiment configuration. Each sub-dict groups the settings used by
# one stage of the notebook; "network_parameters" and "simulation_parameters"
# are unpacked with ** into the model / simulation helpers.
params_dict = {
    "dataset_parameters": {
        "n_samples": 200
    },
    "network_parameters": {
        "input_size": 64,
        "hidden_size": 128,
        "output_size": 1,
        "bias": 1,
    },
    "training_parameters": {
        "num_epochs": 100,
        "learning_rate": 0.01,
        # Step size with which each weight mean drifts toward its current
        # weight during training (mu <- mu + rate * (W_i - mu)).
        "mu_adaptation_rate": 0.0001,
    },
    "simulation_parameters": {
        "mu": 1,
        "sigma": 0.1,
        "theta": 0.02,
        "dt": 0.001,
        "tau": 0.002
    },
    "seed": 42
}

# Root PRNG key; later cells split it to derive independent sub-keys.
rng = random.PRNGKey(params_dict["seed"])


# Create the model

In [None]:
# Dedicated PRNG sub-key for the network initialisation.
rng, net_key = random.split(rng)

# Mean and sigma of the log-normal distribution derived from the simulation
# parameters; both are handed to the ELM initialiser below.
mu_LN, sigma_LN = (
    mu_LN_from_params(**params_dict["simulation_parameters"]),
    sigma_LN_from_params(**params_dict["simulation_parameters"]),
)

# Initialise the two-layer ELM and show which parameter arrays it holds.
params = init_elm(
    net_key,
    mu_LN,
    sigma_LN,
    **params_dict["network_parameters"],
)
print(params.keys())

In [None]:
# Fresh PRNG sub-key so the data does not reuse the initialisation key.
rng, data_key = random.split(rng)

# Sample the binary dataset; its input dimension must match the network's
# input size.
X_train, y_train = create_binary_dataset(
    data_key,
    input_dim=params_dict["network_parameters"]["input_size"],
    n_samples=params_dict["dataset_parameters"]["n_samples"],
)


### Proposed training rule

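Here the means of the input weights are not trained via gradient; instead, they adapt. At every training step, the input weights `W_i` are first perturbed by the stochastic dynamics `time_evolution_GOU`, which is parameterized by the per-weight means `mu`. All parameters then take a gradient step on the current sample. Finally, each mean moves a small step toward its current weight: $\mu \leftarrow \mu + \eta_\mu\,(W_i - \mu)$, where $\eta_\mu$ is `rate`.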

In [None]:

# Unpack the hyper-parameters used by the training and testing loops below.
simulation_parameters = params_dict["simulation_parameters"]
training_parameters = params_dict["training_parameters"]
num_epochs = training_parameters["num_epochs"]
learning_rate = training_parameters["learning_rate"]

# Parameters of the stochastic weight dynamics (time_evolution_GOU).
sigma = simulation_parameters["sigma"]
theta = simulation_parameters["theta"]
dt = simulation_parameters["dt"]
tau = simulation_parameters["tau"]

# One mean per input weight, all starting at the configured value.
mu = jnp.full_like(params["W_i"], simulation_parameters["mu"])

# Step size with which each mean drifts toward its current weight during
# training (mu <- mu + rate * (W_i - mu)). Taken from the config when present,
# otherwise the original hard-coded value 1e-4 is used.
rate = training_parameters.get("mu_adaptation_rate", 0.0001)

In [None]:
# Training phase. For every sample:
#   1. the input weights W_i are perturbed by the stochastic dynamics
#      (time_evolution_GOU) parameterised by the per-weight means `mu`;
#   2. all trainable parameters take one per-sample gradient step;
#   3. each mean moves a step of size `rate` toward its current weight.
# NOTE: this cell updates `params`, `mu` and `rng` without re-initialising
# them, so re-running it continues from the current state.
loss_list = []
acc_list = []
weight_list = []
mu_list = []

# Build the gradient function once instead of once per sample.
# NOTE(review): wrapping it in jit() would likely be much faster, provided
# loss_elm is jit-compatible — not verified here.
loss_grad = grad(loss_elm)
trained_names = ("W_i", "W_o", "b_i", "b_o")

for epoch in range(num_epochs):
    start_time = time.time()
    for x, y in zip(X_train, y_train):
        rng, gou_key = random.split(rng)
        # Perturb the input weights with the stochastic weight dynamics.
        params["W_i"] = time_evolution_GOU(gou_key, params["W_i"], mu, theta, sigma, tau, dt)

        grads = loss_grad(params, x, y)
        for name in trained_names:
            params[name] -= learning_rate * grads[name]

        # Means track the weights as an exponential moving average.
        mu += rate * (params["W_i"] - mu)

        # NOTE(review): a full copy of W_i and of mu is stored at every sample
        # (num_epochs * n_samples * W_i.size values each) — this dominates memory.
        weight_list.append(params["W_i"].flatten())
        mu_list.append(mu.flatten())
    # Duration of the training pass only (the evaluation below is excluded).
    epoch_time = time.time() - start_time

    # Evaluate once per epoch and reuse the results for logging.
    train_acc = accuracy_elm(params, X_train, y_train)
    train_loss = loss_elm(params, X_train, y_train)
    acc_list.append(train_acc)
    loss_list.append(train_loss)

    if epoch % 10 == 0:
        print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
        print("Training set loss {}".format(train_loss))
        print("Training set accuracy {}".format(train_acc))

In [None]:
# Plot the per-epoch training-set accuracy and loss recorded during training.
fig, axs = plt.subplots(1, 2, figsize=(8, 4))

axs[0].plot(acc_list)
axs[0].set_title("Accuracy")
axs[0].set_xlabel("Epoch")
axs[0].set_ylim([0, 1])

axs[1].plot(loss_list)
axs[1].set_title("Loss")
axs[1].set_xlabel("Epoch")

fig.tight_layout()
plt.show()


In [None]:
# Trajectories of the input weights recorded at every training step
# (weights_to_show presumably limits how many are drawn — see plot_weight_dynamics).
plot_weight_dynamics(
    np.asarray(weight_list),
    title="Weight Dynamics",
    weights_to_show=100,
    show=True,
)

In [None]:
# Trajectories of the weight means recorded at every training step
# (weights_to_show presumably limits how many are drawn — see plot_weight_dynamics).
plot_weight_dynamics(
    np.asarray(mu_list),
    title="mu Dynamics",
    weights_to_show=100,
    show=True,
)

In [None]:
# Distribution of the adapted weight means.
plt.hist(mu.flatten(), bins=100)
plt.title("Distribution of mu")
plt.xlabel("mu")
plt.ylabel("Count")
plt.show()

In [None]:
# Distribution of the current input weights.
plt.hist(params["W_i"].flatten(), bins=100)
plt.title("Distribution of W_i")
plt.xlabel("W_i")
plt.ylabel("Count")
plt.show()

In [None]:
# Distribution of each input weight's deviation from its mean.
plt.hist(params["W_i"].flatten() - mu.flatten(), bins=100)
plt.title("Distribution of W_i - mu")
plt.xlabel("W_i - mu")
plt.ylabel("Count")
plt.show()

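### Testing (learning switched off)

Gradient updates and the adaptation of the means are switched off. The input weights `W_i` keep evolving under `time_evolution_GOU` with the now-fixed means `mu`, and loss and accuracy are tracked on the training set.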

In [None]:
# "Testing" phase: learning is switched off — no gradient steps and no
# adaptation of the means. The input weights keep evolving under the
# stochastic dynamics with the now-fixed means `mu`, and loss / accuracy on
# the training set are tracked over time.
test_loss_list = []
test_acc_list = []
test_weight_list = []
test_mu_list = []

# `mu` is not updated in this phase, so flatten it once; the same immutable
# array is stored for every sample instead of a fresh copy each time.
mu_flat = mu.flatten()

for epoch in range(num_epochs):
    start_time = time.time()
    for _ in zip(X_train, y_train):
        rng, gou_key = random.split(rng)
        # Perturb the input weights with the stochastic weight dynamics.
        params["W_i"] = time_evolution_GOU(gou_key, params["W_i"], mu, theta, sigma, tau, dt)

        test_weight_list.append(params["W_i"].flatten())
        test_mu_list.append(mu_flat)
    # Time this phase's own epochs; otherwise the log line below would report
    # a value left over from the training cell.
    epoch_time = time.time() - start_time

    # Evaluate once per epoch and reuse the results for logging.
    train_acc = accuracy_elm(params, X_train, y_train)
    train_loss = loss_elm(params, X_train, y_train)
    test_acc_list.append(train_acc)
    test_loss_list.append(train_loss)

    if epoch % 10 == 0:
        print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
        print("Training set loss {}".format(train_loss))
        print("Training set accuracy {}".format(train_acc))

In [None]:
# Per-epoch history of both phases (training first, then testing) for plotting.
tot_loss_list = [*loss_list, *test_loss_list]
tot_acc_list = [*acc_list, *test_acc_list]

In [None]:
# Accuracy and loss over both phases. The dashed line marks where learning
# stops: its x-position is the number of recorded training epochs.
n_train_epochs = len(acc_list)

fig, axs = plt.subplots(1, 2, figsize=(8, 4))

axs[0].plot(tot_acc_list)
axs[0].set_title("Accuracy")
axs[0].set_xlabel("Epoch")
axs[0].set_ylim([0, 1])
axs[0].axvline(x=n_train_epochs, color='k', linestyle='--')

axs[1].plot(tot_loss_list)
axs[1].set_title("Loss")
axs[1].set_xlabel("Epoch")
axs[1].axvline(x=n_train_epochs, color='k', linestyle='--')

fig.tight_layout()
plt.savefig("weak_weights_adapt.png")
plt.show()

