In [1]:
%matplotlib inline
%config InlineBackend.figure_format='retina'

import numpy as np
import matplotlib.pyplot as plt
import pickle

import os

import sys

module_path = os.path.abspath(os.path.join('../../'))
if module_path not in sys.path:
    sys.path.append(module_path)
    
from mc.mc.network import Network

from mc.tests import test_model_yinyang_integrator_spikes as net_model

from test_tasks.utils import plot_spike_times, calc_loss_interp

from test_tasks.yinyang_integrator_spikes.utils import yinyang_train_dat

In [2]:
col_cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']

import ipdb

In [3]:
# network parameters
N_IN = 4
N_HIDDEN = [200,10]
N_OUT = 3

N_BATCH = 128
N_BATCH_VAL = 32

DT = 0.2

In [4]:
# simulation run parameters
T = 60000#0#0#0
NT = int(T / DT)
T = NT * DT

T_SKIP_BATCH_PLAST = 150

T_VAL = 30000
NT_VAL = int(T_VAL / DT)
T_VAL = NT_VAL * DT

T_SHOW_PATTERNS = 1500#800
NT_SHOW_PATTERNS = int(T_SHOW_PATTERNS / DT)
T_SHOW_PATTERNS = NT_SHOW_PATTERNS * DT

T_IND_UPDATE_PATTERNS = np.arange(NT)[::NT_SHOW_PATTERNS]
N_UPDATE_PATTERNS = T_IND_UPDATE_PATTERNS.shape[0]

T_IND_UPDATE_PATTERNS_VAL = np.arange(NT_VAL)[::NT_SHOW_PATTERNS]
N_UPDATE_PATTERNS_VAL = T_IND_UPDATE_PATTERNS_VAL.shape[0]

NT_SKIP_REC = 5000#0#50000#0
T_IND_REC = np.arange(NT)[::NT_SKIP_REC]
N_REC = T_IND_REC.shape[0]

NT_SKIP_REC_VAL = 300
T_IND_REC_VAL = np.arange(NT_VAL)[::NT_SKIP_REC_VAL]
N_REC_VAL = T_IND_REC_VAL.shape[0]

N_VAL_RUNS = 10
T_IND_VAL_RUNS = np.linspace(0., NT - 1, N_VAL_RUNS).astype("int")

In [5]:
# generate some training data
# random voltage values -> if smaller zero, no output
R_MAX_OUT = 1.#0.25

INP_SET, OUT_SET = yinyang_train_dat(N_UPDATE_PATTERNS * N_BATCH)
OUT_SET *= R_MAX_OUT

INPUT_DATA = np.reshape(INP_SET, (N_UPDATE_PATTERNS, N_BATCH, N_IN))
OUTPUT_DATA = np.reshape(OUT_SET, (N_UPDATE_PATTERNS, N_BATCH, N_OUT))

INPUT_DATA_FLAT = INPUT_DATA.flatten()

OUTPUT_DATA_FLAT = OUTPUT_DATA.flatten()

In [6]:
# generate validation data

INP_SET_VAL, OUT_SET_VAL = yinyang_train_dat(N_UPDATE_PATTERNS_VAL * N_BATCH_VAL)

OUT_SET_VAL *= R_MAX_OUT

INPUT_DATA_VAL = np.reshape(INP_SET_VAL, (N_UPDATE_PATTERNS_VAL, N_BATCH_VAL, N_IN))
OUTPUT_DATA_VAL = np.reshape(OUT_SET_VAL, (N_UPDATE_PATTERNS_VAL, N_BATCH_VAL, N_OUT))

INPUT_DATA_VAL_FLAT = INPUT_DATA_VAL.flatten()

OUTPUT_DATA_VAL_FLAT = OUTPUT_DATA_VAL.flatten()

NEUR_READOUT_VAL = [("neur_input_input_pop", "r", T_IND_REC_VAL),
                    ("neur_input_input_pop", "ca", T_IND_REC_VAL),
                    #
                    ("neur_hidden0_pyr_pop", "vb", T_IND_REC_VAL),
                    ("neur_hidden0_pyr_pop", "vbEff", T_IND_REC_VAL),
                    ("neur_hidden0_pyr_pop", "va", T_IND_REC_VAL),
                    ("neur_hidden0_pyr_pop", "u", T_IND_REC_VAL),
                    ("neur_hidden0_pyr_pop", "rEff", T_IND_REC_VAL),
                    ("neur_hidden0_pyr_pop", "ca", T_IND_REC_VAL),
                    ("neur_hidden0_pyr_pop", "b", T_IND_REC_VAL),
                    #
                    ("neur_hidden0_int_pop", "v", T_IND_REC_VAL),
                    ("neur_hidden0_int_pop", "vEff", T_IND_REC_VAL),
                    ("neur_hidden0_int_pop", "va", T_IND_REC_VAL),
                    ("neur_hidden0_int_pop", "u", T_IND_REC_VAL),
                    ("neur_hidden0_int_pop", "rEff", T_IND_REC_VAL),
                    ("neur_hidden0_int_pop", "ca", T_IND_REC_VAL),
                    ("neur_hidden0_pyr_pop", "b", T_IND_REC_VAL),
                    #
                    ("neur_output_output_pop", "vb", T_IND_REC_VAL),
                    ("neur_output_output_pop", "vbEff", T_IND_REC_VAL),
                    ("neur_output_output_pop", "va", T_IND_REC_VAL),
                    ("neur_output_output_pop", "vnudge", T_IND_REC_VAL),
                    ("neur_output_output_pop", "u", T_IND_REC_VAL),
                    ("neur_output_output_pop", "rEff", T_IND_REC_VAL),
                    ("neur_output_output_pop", "ca", T_IND_REC_VAL),
                    ("neur_output_output_pop", "b", T_IND_REC_VAL)]
'''
                    ("neur_hidden1_pyr_pop", "vb", T_IND_REC_VAL),
                    ("neur_hidden1_pyr_pop", "vbEff", T_IND_REC_VAL),
                    ("neur_hidden1_pyr_pop", "va", T_IND_REC_VAL),
                    ("neur_hidden1_pyr_pop", "u", T_IND_REC_VAL),
                    ("neur_hidden1_pyr_pop", "rEff", T_IND_REC_VAL),
                    ("neur_hidden1_pyr_pop", "ca", T_IND_REC_VAL),
                    #
                    ("neur_hidden1_int_pop", "v", T_IND_REC_VAL),
                    ("neur_hidden1_int_pop", "vEff", T_IND_REC_VAL),
                    ("neur_hidden1_int_pop", "u", T_IND_REC_VAL),
                    ("neur_hidden1_int_pop", "rEff", T_IND_REC_VAL),
                    ("neur_hidden1_int_pop", "ca", T_IND_REC_VAL),
                    '''

DICT_DATA_VALIDATION = {"T": NT_VAL,
                        "t_sign": T_IND_UPDATE_PATTERNS_VAL,
                        "ext_data_input": INPUT_DATA_VAL_FLAT,
                        "ext_data_pop_vars": [
                            (np.zeros((1, N_BATCH_VAL, N_OUT)),
                             np.array([0.]).astype("int"),
                             "neur_output_output_pop", "gnudge")],
                        "readout_neur_pop_vars": NEUR_READOUT_VAL}

# copy it for each validation
DATA_VALIDATION = [DICT_DATA_VALIDATION] * N_VAL_RUNS

NEUR_POPS_SPIKE_REC_VAL = []
#'''
#["neur_input_input_pop",
#                           "neur_hidden0_pyr_pop",
#                           "neur_hidden0_int_pop",
#                           "neur_output_output_pop"]'''

In [7]:
# recording settings
NEUR_POPS_SPIKE_REC = []
'''["neur_input_input_pop",
                       "neur_hidden0_pyr_pop",
                       "neur_hidden0_int_pop",
                       "neur_output_output_pop"]'''

NEUR_READOUT = [("neur_input_input_pop", "r", T_IND_REC),
                    ("neur_input_input_pop", "ca", T_IND_REC),
                    #
                    ("neur_hidden0_pyr_pop", "vb", T_IND_REC),
                    ("neur_hidden0_pyr_pop", "vbEff", T_IND_REC),
                    ("neur_hidden0_pyr_pop", "va", T_IND_REC),
                    ("neur_hidden0_pyr_pop", "va_int", T_IND_REC),
                    ("neur_hidden0_pyr_pop", "va_exc", T_IND_REC),
                    ("neur_hidden0_pyr_pop", "u", T_IND_REC),
                    ("neur_hidden0_pyr_pop", "rEff", T_IND_REC),
                    ("neur_hidden0_pyr_pop", "ca", T_IND_REC),
                    ("neur_hidden0_pyr_pop", "b", T_IND_REC),
                    #
                    ("neur_hidden0_int_pop", "v", T_IND_REC),
                    ("neur_hidden0_int_pop", "vEff", T_IND_REC),
                    ("neur_hidden0_int_pop", "va", T_IND_REC),
                    ("neur_hidden0_int_pop", "u", T_IND_REC),
                    ("neur_hidden0_int_pop", "rEff", T_IND_REC),
                    ("neur_hidden0_int_pop", "ca", T_IND_REC),
                    ("neur_hidden0_int_pop", "b", T_IND_REC),
                    #
                    ("neur_output_output_pop", "vb", T_IND_REC),
                    ("neur_output_output_pop", "vbEff", T_IND_REC),
                    ("neur_output_output_pop", "vnudge", T_IND_REC),
                    ("neur_output_output_pop", "va", T_IND_REC),
                    ("neur_output_output_pop", "u", T_IND_REC),
                    ("neur_output_output_pop", "rEff", T_IND_REC),
                    ("neur_output_output_pop", "ca", T_IND_REC),
                    ("neur_output_output_pop", "b", T_IND_REC)]
'''                   #
                    ("neur_hidden1_pyr_pop", "vb", T_IND_REC),
                    ("neur_hidden1_pyr_pop", "vbEff", T_IND_REC),
                    ("neur_hidden1_pyr_pop", "va", T_IND_REC),
                    ("neur_hidden1_pyr_pop", "va_int", T_IND_REC),
                    ("neur_hidden1_pyr_pop", "va_exc", T_IND_REC),
                    ("neur_hidden1_pyr_pop", "u", T_IND_REC),
                    ("neur_hidden1_pyr_pop", "rEff", T_IND_REC),
                    ("neur_hidden1_pyr_pop", "ca", T_IND_REC),
                    #
                    ("neur_hidden1_int_pop", "v", T_IND_REC),
                    ("neur_hidden1_int_pop", "vEff", T_IND_REC),
                    ("neur_hidden1_int_pop", "u", T_IND_REC),
                    ("neur_hidden1_int_pop", "rEff", T_IND_REC),
                    ("neur_hidden1_int_pop", "ca", T_IND_REC),
                    #'''

SYN_READOUT = [("syn_input_input_pop_to_hidden0_pyr_pop",
                "g", T_IND_REC),
               ("syn_hidden0_pyr_pop_to_int_pop",
                "g", T_IND_REC),
               ("syn_hidden0_int_pop_to_pyr_pop",
                "g", T_IND_REC),
               ("syn_hidden0_pyr_pop_to_output_output_pop",
                "g", T_IND_REC),
               ("syn_output_output_pop_to_hidden0_pyr_pop",
                "g", T_IND_REC)]
               #("syn_hidden0_pyr_pop_to_hidden1_pyr_pop",
               # "g", T_IND_REC),
               #("syn_hidden1_pyr_pop_to_int_pop",
               # "g", T_IND_REC),
               #("syn_hidden1_int_pop_to_pyr_pop",
               # "g", T_IND_REC),
               #("syn_hidden1_pyr_pop_to_output_output_pop",
               # "g", T_IND_REC),
               #("syn_output_output_pop_to_hidden1_pyr_pop",
               # "g", T_IND_REC)]

In [8]:
# Initialize network
net = Network("testnet",
              net_model,
              N_IN, N_HIDDEN, N_OUT,
              N_UPDATE_PATTERNS,
              NT,
              NT_VAL,
              dt=DT,
              spike_rec_pops=NEUR_POPS_SPIKE_REC,
              spike_rec_pops_val=NEUR_POPS_SPIKE_REC_VAL,
              plastic=True,
              t_inp_static_max=N_UPDATE_PATTERNS_VAL,
              n_batches=N_BATCH,
              n_batches_val=N_BATCH_VAL)

IndexError: vector::_M_range_check: __n (which is 9) >= this->size() (which is 9)

In [None]:
# match fb with ffwd
'''
net.syn_pops["syn_hidden0_pyr_pop_to_output_output_pop"].pull_var_from_device("g")
w_pp = net.syn_pops["syn_hidden0_pyr_pop_to_output_output_pop"].vars["g"].view
w_pp = np.reshape(w_pp, (N_HIDDEN[0], N_OUT))
w_ppb = (w_pp.T).flatten()


net.syn_pops["syn_output_output_pop_to_hidden0_pyr_pop"].vars["g"].view[:] = w_ppb
net.syn_pops["syn_output_output_pop_to_hidden0_pyr_pop"].push_var_to_device("g")
'''

# set the IP and PI weights into the self-predicting state.
net.init_self_pred_state()

In [None]:
(results_neur, results_syn,
 results_spike, results_validation) = net.run_sim(NT,
                                                  T_IND_UPDATE_PATTERNS,
                                                  INPUT_DATA_FLAT, OUTPUT_DATA_FLAT,
                                                  None,
                                                  NEUR_READOUT,
                                                  SYN_READOUT,
                                                  T_IND_VAL_RUNS,
                                                  DATA_VALIDATION,
                                                  T_skip_batch_plast=T_SKIP_BATCH_PLAST,
                                                  align_fb=True,
                                                  enforce_self_pred=True)

In [None]:
# result vars
inp_r = results_neur["neur_input_input_pop_r"]
inp_ca = results_neur["neur_input_input_pop_ca"]
#inp_sp = results_spike["neur_input_input_pop"]

p_vb = [results_neur[f"neur_hidden{k}_pyr_pop_vb"] for k in range(len(N_HIDDEN))]
p_vbeff = [results_neur[f"neur_hidden{k}_pyr_pop_vbEff"] for k in range(len(N_HIDDEN))]
p_va = [results_neur[f"neur_hidden{k}_pyr_pop_va"] for k in range(len(N_HIDDEN))]
p_va_int = [results_neur[f"neur_hidden{k}_pyr_pop_va_int"] for k in range(len(N_HIDDEN))]
p_va_exc = [results_neur[f"neur_hidden{k}_pyr_pop_va_exc"] for k in range(len(N_HIDDEN))]
p_u = [results_neur[f"neur_hidden{k}_pyr_pop_u"] for k in range(len(N_HIDDEN))]
p_reff = [results_neur[f"neur_hidden{k}_pyr_pop_rEff"] for k in range(len(N_HIDDEN))]
p_ca = [results_neur[f"neur_hidden{k}_pyr_pop_ca"] for k in range(len(N_HIDDEN))]
p_b = [results_neur[f"neur_hidden{k}_pyr_pop_b"] for k in range(len(N_HIDDEN))]
#p_bias = [results_neur[f"neur_hidden{k}_pyr_pop_I_bias"] for k in range(len(N_HIDDEN))]

i_v = [results_neur[f"neur_hidden{k}_int_pop_v"] for k in range(len(N_HIDDEN))]
i_veff = [results_neur[f"neur_hidden{k}_int_pop_vEff"] for k in range(len(N_HIDDEN))]
i_va = [results_neur[f"neur_hidden{k}_int_pop_va"] for k in range(len(N_HIDDEN))]
i_u = [results_neur[f"neur_hidden{k}_int_pop_u"] for k in range(len(N_HIDDEN))]
i_reff = [results_neur[f"neur_hidden{k}_int_pop_rEff"] for k in range(len(N_HIDDEN))]
i_ca = [results_neur[f"neur_hidden{k}_int_pop_ca"] for k in range(len(N_HIDDEN))]
i_b = [results_neur[f"neur_hidden{k}_int_pop_b"] for k in range(len(N_HIDDEN))]
#i_bias = [results_neur[f"neur_hidden{k}_int_pop_I_bias"] for k in range(len(N_HIDDEN))]

out_vb = results_neur["neur_output_output_pop_vb"]
out_vbeff = results_neur["neur_output_output_pop_vbEff"]
out_va = results_neur["neur_output_output_pop_va"]
out_u = results_neur["neur_output_output_pop_u"]
out_vnudge = results_neur["neur_output_output_pop_vnudge"]
out_reff = results_neur["neur_output_output_pop_rEff"]
out_ca = results_neur["neur_output_output_pop_ca"]
out_b = results_neur["neur_output_output_pop_b"]
######################################################

val_out_r = results_validation[-1]["neur_var_rec"]["neur_output_output_pop_ca"]
val_in_r = results_validation[-1]["neur_var_rec"]["neur_input_input_pop_r"]

w_pinp = results_syn["syn_input_input_pop_to_hidden0_pyr_pop_g"]
w_ip = results_syn["syn_hidden0_pyr_pop_to_int_pop_g"]
w_pi = results_syn["syn_hidden0_int_pop_to_pyr_pop_g"]
w_pp = results_syn["syn_hidden0_pyr_pop_to_output_output_pop_g"]
w_ppb = results_syn["syn_output_output_pop_to_hidden0_pyr_pop_g"]

In [None]:
np.savez("data_sim.npz",
        # result vars
        inp_r = inp_r,
        inp_ca = inp_ca,
        #inp_sp = results_spike["neur_input_input_pop"]
        p_vb = p_vb,
        p_vbeff = p_vbeff,
        p_va = p_va,
        p_va_int = p_va_int,
        p_va_exc = p_va_exc,
        p_u = p_u,
        p_reff = p_reff,
        p_ca = p_ca,
        p_b = p_b,
        #p_bias = [results_neur[f"neur_hidden{k}_pyr_pop_I_bias"] for k in range(len(N_HIDDEN))]
        i_v = i_v,
        i_veff = i_veff,
        i_va = i_va,
        i_u = i_u,
        i_reff = i_reff,
        i_ca = i_ca,
        i_b = i_b,
        #i_bias = [results_neur[f"neur_hidden{k}_int_pop_I_bias"] for k in range(len(N_HIDDEN))]
        out_vb = out_vb,
        out_vbeff = out_vbeff,
        out_va = out_va,
        out_u = out_u,
        out_vnudge = out_vnudge,
        out_reff = out_reff,
        out_ca = out_ca,
        out_b = out_b,
        ######################################################
        val_out_r = val_out_r,
        val_in_r = val_in_r,
        #
        w_pinp = w_pinp,
        w_ip = w_ip,
        w_pi = w_pi,
        w_pp = w_pp,
        w_ppb = w_ppb)

In [None]:
data_results = np.load("data_sim.npz")

inp_r = data_results["inp_r"]
inp_ca = data_results["inp_ca"]
#inp_sp = results_spike["neur_input_input_pop"]
p_vb = data_results["p_vb"]
p_vbeff = data_results["p_vbeff"]
p_va = data_results["p_va"]
p_va_int = data_results["p_va_int"]
p_va_exc = data_results["p_va_exc"]
p_u = data_results["p_u"]
p_reff = data_results["p_reff"]
p_ca = data_results["p_ca"]
p_b = data_results["p_b"]
#p_bias = [results_neur[f"neur_hidden{k}_pyr_pop_I_bias"] for k in range(len(N_HIDDEN))]
i_v = data_results["i_v"]
i_veff = data_results["i_veff"]
i_va = data_results["i_va"]
i_u = data_results["i_u"]
i_reff = data_results["i_reff"]
i_ca = data_results["i_ca"]
i_b = data_results["i_b"]
#i_bias = [results_neur[f"neur_hidden{k}_int_pop_I_bias"] for k in range(len(N_HIDDEN))]
out_vb = data_results["out_vb"]
out_vbeff = data_results["out_vbeff"]
out_va = data_results["out_va"]
out_u = data_results["out_u"]
out_vnudge = data_results["out_vnudge"]
out_reff = data_results["out_reff"]
out_ca = data_results["out_ca"]
out_b = data_results["out_b"]
######################################################
val_out_r = data_results["val_out_r"]
val_in_r = data_results["val_in_r"]
#
w_pinp = data_results["w_pinp"]
w_ip = data_results["w_ip"]
w_pi = data_results["w_pi"]
w_pp = data_results["w_pp"]
w_ppb = data_results["w_ppb"]

In [None]:
p_va_pred = np.ndarray((N_REC,N_BATCH,N_HIDDEN[0]))
for k in range(N_REC):
    for l in range(N_BATCH):
        p_va_pred[k,l] = w_ppb[k] @ out_va[k,l]

In [None]:
p_va_exc_pred = np.ndarray((N_REC,N_BATCH,N_HIDDEN[0]))
for k in range(N_REC):
    for l in range(N_BATCH):
        p_va_exc_pred[k,l] = w_ppb[k] @ out_ca[k,l]

In [None]:
loss = np.ndarray((N_VAL_RUNS))

for k in range(N_VAL_RUNS):

    _ca_readout = results_validation[k]["neur_var_rec"]["neur_output_output_pop_ca"]

    loss[k] = calc_loss_interp(T_IND_UPDATE_PATTERNS_VAL,
                               T_IND_REC_VAL,
                               OUTPUT_DATA_VAL, _ca_readout,
                               perc_readout_targ_change=0.9)

In [None]:
fig_r, ax_r = plt.subplots(1, 3, figsize=(15, 5))

for k in range(3):
    cax = ax_r[k]
    for i in range(N_BATCH_VAL):
        cax.scatter(val_in_r[:, i, 0], val_in_r[:, i, 1], c=np.minimum(1.,val_out_r[:, i, k]))
        #cax.scatter(val_in_r[:, i, 0], val_in_r[:, i, 1], c=np.argmax(val_out_r[:, i, :],axis=1))

fig_r.savefig("class_pred.png",dpi=600)
        
plt.show()

In [None]:
fig_loss, ax_loss = plt.subplots(1, 1)

ax_loss.plot(T_IND_VAL_RUNS*DT, loss, '-o')
ax_loss.set_yscale("log")

ax_loss.set_ylabel("MSE")
ax_loss.set_xlabel(r'$t$')

fig_loss.savefig("loss.png",dpi=600)

plt.show()

In [None]:
plt.pcolormesh(p_ca[0][:,0,:].T)
plt.colorbar()
plt.show()

In [None]:
plt.plot(inp_ca[:,0,:])
plt.show()

In [None]:
fig_pred, ax_pred = plt.subplots(1, 3, figsize=(15,5))
for k in range(3):
    ax_pred[k].step(T_IND_UPDATE_PATTERNS_VAL*DT,OUTPUT_DATA_VAL[:, 0, k], where="post")
    ax_pred[k].plot(T_IND_REC_VAL*DT, results_validation[-1]["neur_var_rec"]["neur_output_output_pop_ca"][:, 0, k])
    ax_pred[k].set_xlabel(r'$t$')
    ax_pred[k].set_ylabel("Output Rate")

fig_pred.savefig("rate_pred.png",dpi=600)
    
plt.show()

In [None]:
plt.plot(p_va_pred[:,0,0]*0.5)
plt.plot(p_va[0][:,0,0])
plt.show()

In [None]:
plt.plot(out_va[:,0,0]*(1.-out_ca[:,0,0])**2.*np.minimum(1.,out_ca[:,0,0]*50.))
plt.plot(out_ca[:,0,0])
plt.show()

In [None]:
fig_w, ax_w = plt.subplots(1,1)

for readout in SYN_READOUT:
    _wrec = np.linalg.norm(results_syn[f'{readout[0]}_g'], axis=(1,2))
    ax_w.plot(_wrec, label = readout[0])

ax_w.legend()

fig_w.savefig("weight_norm.png",dpi=600)

plt.show()

In [None]:
plt.plot(w_pp[:,:,0])
plt.show()

In [None]:
plt.plot(p_)

In [None]:
w_pp.shape

In [None]:
w_align = np.ndarray((N_REC,N_OUT))
for k in range(N_REC):
    w_align[k] = (w_pp[k] * w_ppb[k].T).sum(axis=1)
    w_align[k] /= np.linalg.norm(w_pp[k],axis=1) * np.linalg.norm(w_ppb[k],axis=0)

In [None]:
fig_align, ax_align = plt.subplots()
ax_align.plot(w_align)

fig_align.savefig("fb_align.png",dpi=600)

plt.show()