In [2]:
from rnn import QIFExpAddSyns
from rnn import mQIFExpAddSynsRNN
import numpy as np
import pickle
from scipy.ndimage import gaussian_filter1d


def kuramoto_order_parameter(r, v):
    """Return the modulus of the Kuramoto order parameter, element-wise.

    Forms the complex variable ``W = pi*r + i*v`` and maps it through
    ``Z = (1 - conj(W)) / (1 + conj(W))``; ``|Z|`` is returned. In this
    script ``r`` and ``v`` are the (network-averaged or mean-field) traces
    named ``r_*`` and ``v_*`` at the call sites.

    Parameters
    ----------
    r : array_like
        Real-part input (scaled by pi). Any shape broadcastable with ``v``.
    v : array_like
        Imaginary-part input. Any shape broadcastable with ``r``.

    Returns
    -------
    numpy.ndarray
        ``|Z|`` with the broadcast shape of ``r`` and ``v``.

    Raises
    ------
    ValueError
        If ``r`` and ``v`` cannot be broadcast together (mismatched lengths
        are reported instead of being silently truncated).
    """
    r, v = np.broadcast_arrays(np.asarray(r, dtype=float), np.asarray(v, dtype=float))
    # Fill real and imaginary parts directly: identical to complex(pi*r, v)
    # per element, without the inf -> nan artefacts of computing 1j * v.
    W = np.empty(r.shape, dtype=complex)
    W.real = np.pi * r
    W.imag = v
    W_c = W.conjugate()
    return np.abs((1 - W_c) / (1 + W_c))


# STEP 0: Define simulation condition
#####################################

# Condition index. NOTE: hard-coded here (not parsed from script arguments)
# and not read anywhere else in the visible script.
idx_cond = 1

# STEP 1: Load pre-generated RNN parameters
###########################################

# The `with` block closes the file handle deterministically.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open("qif_input_config.pkl", 'rb') as config_file:
    config = pickle.load(config_file)

# connectivity matrix
C = config['C']

# input
inp = config['inp']

# input weights
W_in = config['W_in']

# simulation config (presumably: total time, integration step, sampling step,
# and the initial period discarded from the recordings -- names only)
T = config['T']
dt = config['dt']
dts = config['dts']
cutoff = config['cutoff']
# number of recorded samples after the cutoff; sizes the result arrays.
# NOTE(review): int() truncates, so a float quotient such as 8999.9999 yields
# one sample fewer -- confirm this matches the length the simulator returns.
t = int((T - cutoff)/dts)
M = config['number_input_channels']

# target values
y = config['targets']

# STEP 2: define remaining network parameters
#############################################

# general parameters
N = C.shape[0]  # network size: number of rows of the connectivity matrix
m = W_in.shape[0]
# NOTE(review): n_folds is not used in the visible script (ridge_fit is
# called with k=0).
n_folds = 5
# NOTE: 10e-3 == 1e-2, so this evaluates to 5e-3. Confirm this (and not
# 0.5e-3) is the intended ridge regularization strength.
ridge_alpha = 0.5*10e-3

# qif parameters
Delta = 0.3
eta = -0.2
tau_a = 10.0
tau_s = 0.5

# adaptation strength
alpha = 0.3

# independent variable (IV): the values of J swept by the simulation loop;
# each value is passed as the third positional argument of the QIF models
iv_name = "J"
n_iv = 2
ivs = np.linspace(0, 20, num=n_iv)

# mean-field parameters
# single-node model: 1x1 coupling matrix
C_mf = np.ones(shape=(1, 1))
# Mean-field input: a rectangular pulse of amplitude 1.0 on every row of
# inp_mf. It lasts 0.2 time units and starts a quarter of the way into the
# post-cutoff window. Axis 1 is indexed in steps of dt.
# (assumes rows of inp are input channels -- confirm)
inp_mf = np.zeros_like(inp)
in_start = int(0.25*(T - cutoff)/dt) + int(cutoff/dt)
in_dur = int(0.2/dt)
inp_mf[:, in_start:in_start+in_dur] = 1.0
# unit weights from all M input channels onto the single node, shape (1, M)
W_in_mf = np.ones((1, M)) * 1.0

# STEP 3: Evaluate classification performance of RNN
####################################################

# Result container: scalar scores per IV value (shape (n_iv,)) and
# time series per IV value (shape (n_iv, t)), plus the IV itself.
data = {
    "score": np.zeros((n_iv,)),
    "wta_score": np.zeros((n_iv,)),
    **{series: np.zeros((n_iv, t))
       for series in ("r_qif", "v_qif", "r_mf", "v_mf", "Z_qif", "Z_mf")},
    "iv": ivs,
    "iv_name": iv_name,
}
# simulation loop over the values of the independent variable
for j, iv in enumerate(ivs):

    print(f'Performing simulations for {iv_name} = {iv} ...')

    # setup QIF RNN
    qif_rnn = QIFExpAddSyns(C, eta, iv, Delta=Delta, alpha=alpha, tau_s=tau_s, tau_a=tau_a, tau=1.0)

    # perform simulation, recording state variables 0..N-1 and 3N..4N-1;
    # their averages over units (axis 1) are kept as v_qif and r_qif
    results = qif_rnn.run(T, dt, dts, inp=inp, W_in=W_in, cutoff=cutoff, outputs=(np.arange(0, N), np.arange(3*N, 4*N)))
    v_qif = np.mean(results[0], axis=1)
    r_qif = np.mean(results[1], axis=1)

    # prepare training data: smooth every unit's trace along the time axis
    # (axis 0) with a Gaussian kernel of sigma = 0.1/dts samples, zero-padded
    # at the borders. One vectorized call applies the same 1-D filter to each
    # column, as a per-column loop would.
    buffer_val = 0
    X = gaussian_filter1d(results[1], 0.1/dts, axis=0, mode='constant', cval=buffer_val)
    r_qif2 = np.mean(X, axis=1)

    # split into training (first 75 % of samples) and test data
    split = int(np.round(X.shape[0]*0.75, decimals=0))
    X_train = X[:split, :]
    y_train = y[:split, :]
    X_test = X[split:, :]
    y_test = y[split:, :]

    # train RNN readout (k=0 -- presumably no cross-validation; confirm in rnn.py)
    key, scores, coefs = qif_rnn.ridge_fit(X=X_train, y=y_train, alpha=ridge_alpha, k=0, fit_intercept=False,
                                           copy_X=True, solver='lsqr', readout_key=f'qif_m{M}', verbose=False)

    # calculate classification score on test data
    score, y_predict = qif_rnn.test(X=X_test, y=y_test, readout_key=key)

    # winner-takes-all classification: a test sample counts as correct when the
    # largest prediction and the largest target sit in the same output column
    wta_pred = y_predict.argmax(axis=1)
    wta_target = y_test.argmax(axis=1)
    wta_score = np.mean(wta_pred == wta_target)

    # simulate mean-field dynamics
    # NOTE(review): in the recorded run of this cell, execution stopped after the
    # QIF network simulation with a numba TypingError inside rnn.py
    # (mqif_update, line 446: `s[0, :]` applied to a 1-D array). That appears to
    # come from this mean-field run; the fix belongs in rnn.py / the mean-field
    # model setup rather than in this loop.
    qif_mf = mQIFExpAddSynsRNN(C_mf, eta, iv, Delta=Delta, tau=1.0, alpha=alpha, tau_a=tau_a, tau_s=tau_s)
    results = qif_mf.run(T, dt, dts, inp=inp_mf, W_in=W_in_mf, cutoff=cutoff, outputs=([0], [1]))
    v_mf = np.squeeze(results[0])
    r_mf = np.squeeze(results[1])

    # calculate Kuramoto order parameter Z for QIF network and mean-field model
    Z_qif = kuramoto_order_parameter(r_qif, v_qif)
    Z_mf = kuramoto_order_parameter(r_mf, v_mf)

    print(f"Finished. Results: WTA = {wta_score}, mean(Z) = {np.mean(Z_qif)}.")

    # store data
    # NOTE(review): data["r_qif"] receives the *smoothed* rate r_qif2, while
    # Z_qif and data["v_qif"] are based on unsmoothed traces -- confirm intent.
    data["score"][j] = score
    data["wta_score"][j] = wta_score
    data["r_qif"][j, :] = r_qif2
    data["v_qif"][j, :] = v_qif
    data["r_mf"][j, :] = r_mf
    data["v_mf"][j, :] = v_mf
    data["Z_qif"][j, :] = Z_qif
    data["Z_mf"][j, :] = Z_mf

# Persist all results; the `with` block guarantees the pickle is flushed and
# the file closed before the cell finishes.
data["T"] = T
with open('qif_rc_multichannel_results2.pkl', 'wb') as results_file:
    pickle.dump(data, results_file)


Performing simulations for J = 0.0 ...
Finished simulation. The state recordings are available under `state_records`.


TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function getitem>) found for signature:
 
 >>> getitem(array(float64, 1d, C), Tuple(Literal[int](0), slice<a:b>))
 
There are 22 candidate implementations:
      - Of which 20 did not match due to:
      Overload of function 'getitem': File: <numerous>: Line N/A.
        With argument(s): '(array(float64, 1d, C), Tuple(int64, slice<a:b>))':
       No match.
      - Of which 1 did not match due to:
      Overload in function 'GetItemBuffer.generic': File: numba/core/typing/arraydecl.py: Line 162.
        With argument(s): '(array(float64, 1d, C), Tuple(int64, slice<a:b>))':
       Rejected as the implementation raised a specific error:
         TypeError: cannot index array(float64, 1d, C) with 2 indices: Tuple(int64, slice<a:b>)
  raised from /Users/willi/opt/anaconda3/lib/python3.8/site-packages/numba/core/typing/arraydecl.py:84
      - Of which 1 did not match due to:
      Overload in function 'GetItemBuffer.generic': File: numba/core/typing/arraydecl.py: Line 162.
        With argument(s): '(array(float64, 1d, C), Tuple(Literal[int](0), slice<a:b>))':
       Rejected as the implementation raised a specific error:
         TypeError: cannot index array(float64, 1d, C) with 2 indices: Tuple(Literal[int](0), slice<a:b>)
  raised from /Users/willi/opt/anaconda3/lib/python3.8/site-packages/numba/core/typing/arraydecl.py:84

During: typing of intrinsic-call at /Users/willi/Documents/Masterarbeit/PycharmProjects/Github/rnn.py (446)
During: typing of static-get-item at /Users/willi/Documents/Masterarbeit/PycharmProjects/Github/rnn.py (446)

File "rnn.py", line 446:
        def mqif_update(u: np.ndarray, inp: np.ndarray, W_in: np.ndarray, N: int, C: np.ndarray, eta: np.ndarray,
            <source elided>
            #net_inp = J * s * tau
            net_inp = J*s[0, :]*tau
            ^
