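# Fit an Adaptive-Threshold Integrate-and-Fire (AdapIF) Model to Hodgkin–Huxley (HH) Spike Trains

This notebook fits the parameters of an integrate-and-fire neuron with an adaptive spike threshold (`alpha`, `a`, `R`, `tau`, `taut` and the refractory period `D`) to target spike trains for two input-current traces (the `1nA` and `2nA` data files). It uses the gamma-factor spike-coincidence metric and then compares the best-fit model against the data.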

In [1]:
from brian2 import *
from brian2modelfitting import *

import numpy as np



In [2]:
# Integration time step, shared by Brian's default clock and the fitter below.
defaultclock.dt = dt = 0.01 * ms

## Load the Data

In [3]:
# Recorded membrane-potential traces, one CSV per input-current recording.
voltage1, voltage2 = (genfromtxt(path, delimiter=',')
                      for path in ('voltage1nA.csv', 'voltage2nA.csv'))

In [4]:
# Injected input currents, one array per recording. They are presumably in nA,
# given the `* nA` applied when building the fitter (TODO confirm).
inp_cur1, inp_cur2 = (genfromtxt(path, delimiter=',')
                      for path in ('inp_cur1nA.csv', 'inp_cur2nA.csv'))
inp_current = [inp_cur1, inp_cur2]
inp_current

[array([ 0.        , -0.01416314,  0.06263459, ..., -0.44478035,
        -0.39163954, -0.33014548]),
 array([ 0.        ,  0.39594954,  0.27878635, ..., -0.55344109,
        -0.71057181, -0.62256885])]

In [5]:
# Target spike times, one array per recording. The files do not state their
# units (TODO confirm).
out_spikes1, out_spikes2 = (genfromtxt(path, delimiter=',')
                            for path in ('out_spikes1nA.csv', 'out_spikes2nA.csv'))
out_spikes = [out_spikes1, out_spikes2]

## Model Fitting

In [6]:
# Adaptive-threshold integrate-and-fire model to be fitted.
#   v  : membrane potential; relaxes towards R*I with time constant tau.
#        I is the input variable injected by the fitter (input_var='I').
#   vt : spike threshold; relaxes towards a*v with time constant taut.
# alpha, a, R, tau, taut and D are per-neuron constants that the optimiser
# searches over. D is passed to the fitter as the refractory period (refractory='D').
#
# Reference parameter values. Their provenance is not recorded in this notebook
# (TODO document), and there is no reference value for D:
# tau = 12.44*ms
# taut = 97.64*ms
# a = 0.21
# alpha = 7.2*mV
# R = 76*Mohm

model = '''
        dv/dt = (R*I- v)/tau :volt
        dvt/dt = (a*v - vt) / taut :volt
        alpha : volt (constant)
        a : 1 (constant)
        R : ohm (constant)
        tau : second (constant)
        taut : second (constant)
        D: second (constant)
        '''
# Applied on every spike: v is reset to 0 mV and the threshold is raised by alpha.
reset = '''
v = 0*mV
vt = vt + alpha
'''

In [7]:
# Build and run simulations as C++ standalone code in ./parallel.
# clean=False keeps existing files in that directory instead of wiping them first.
set_device('cpp_standalone', clean=False, directory='parallel')

In [8]:
# Spike-train similarity metric with a 2 ms coincidence window.
# NOTE(review): `time` should equal the recording duration. 50 s is hard-coded
# here; confirm it matches len(inp_cur1) * dt.
metric = GammaFactor(delta=2*ms, time=50*second)

# Global optimiser backed by Nevergrad, using library defaults.
n_opt = NevergradOptimizer()

In [None]:
# A single fitter over both recordings. Each optimisation round simulates
# n_samples candidate parameter sets.
fitter = SpikeFitter(
    model=model,
    threshold='v > vt',
    reset=reset,
    refractory='D',
    input_var='I',
    input=inp_current * nA,
    output=out_spikes,
    dt=dt,
    n_samples=1000,
)

In [None]:
# Search bounds for every fitted parameter. dict() preserves insertion order,
# so the keywords reach fit() in the same order as before.
param_bounds = dict(
    alpha=[1, 20]*mV,
    a=[0.1, 5],
    R=[1, 500]*Mohm,
    tau=[9, 50]*ms,
    taut=[60, 120]*ms,
    D=[0.5, 2]*ms,
)
# Five optimisation rounds. callback='text' prints the best fit and its error
# after each round.
result_dict, error = fitter.fit(n_rounds=5, optimizer=n_opt, metric=metric,
                                callback='text', **param_bounds)

Round 0: fit (0.0011734163321673696, 1.395984835969274, 0.04087378166417265, 207866590.0330437, 0.08997833096837347, 0.014984551918012148) with error: 0.009117343637483885
Round 1: fit (0.0011734163321673696, 1.395984835969274, 0.04087378166417265, 207866590.0330437, 0.08997833096837347, 0.014984551918012148) with error: 0.009117343637483885
Round 2: fit (0.0011734163321673696, 1.395984835969274, 0.04087378166417265, 207866590.0330437, 0.08997833096837347, 0.014984551918012148) with error: 0.009117343637483885


In [None]:
# Best-fit parameter values returned by fitter.fit().
result_dict

In [None]:
# Metric error of the best fit, as returned by fitter.fit().
error

In [None]:
# Complement of the reported error. This is presumably the gamma factor itself
# (1 = perfect match); TODO confirm against GammaFactor's definition of its error.
1 - error

In [None]:
# Reset and re-activate the cpp_standalone device before running another
# simulation. Presumably required because fit() already built and ran the
# standalone project (TODO confirm).
device.reinit()
device.activate()

In [None]:
# Re-simulate with the best parameters found by fit() (params=None) and keep
# the generated spike trains (one per input trace) for the comparisons below.
spikes = fitter.generate_spikes(params=None)

In [None]:
# Spike counts for trace 0: recorded target first, then the fitted model.
for train in (out_spikes[0], spikes[0]):
    print(len(train))

In [None]:
# Spike counts for trace 1: recorded target first, then the fitted model.
for train in (out_spikes[1], spikes[1]):
    print(len(train))

In [None]:
# Raster of *all* spikes of each trace.
#   Row 1 (y=1): recorded target spikes.
#   Row 2 (y=2): spikes of the best-fit model.
# d is None rather than -1 because slicing with [:-1] would silently drop the
# last spike of every train; [:None] keeps the whole array.
d = None

fig, ax = plt.subplots(nrows=2, figsize=(15, 2))
for trace_ax, target, fitted in zip(ax, out_spikes, spikes):
    trace_ax.scatter(target[:d], np.ones_like(target[:d]))
    trace_ax.scatter(fitted[:d], np.ones_like(fitted[:d]) * 2)
    # Label the two raster rows. np.arange(0, 1, step=1) produced a single
    # tick at y=0, where nothing is plotted.
    trace_ax.set_yticks([1, 2])
    trace_ax.set_yticklabels(['data', 'model'])

In [None]:
# Raster of the first d spikes of each trace.
#   Row 1 (y=1): recorded target spikes.
#   Row 2 (y=2): spikes of the best-fit model.
d = 500

fig, ax = plt.subplots(nrows=2, figsize=(15, 2))
for trace_ax, target, fitted in zip(ax, out_spikes, spikes):
    trace_ax.scatter(target[:d], np.ones_like(target[:d]))
    trace_ax.scatter(fitted[:d], np.ones_like(fitted[:d]) * 2)
    # Label the two raster rows. np.arange(0, 1, step=1) produced a single
    # tick at y=0, where nothing is plotted.
    trace_ax.set_yticks([1, 2])
    trace_ax.set_yticklabels(['data', 'model'])

In [None]:
# Reset and re-activate the cpp_standalone device again before the voltage
# simulation below. Presumably required because the previous standalone run
# has already been built and executed (TODO confirm).
device.reinit()
device.activate()

In [None]:
# Membrane-potential traces ('v') of the best-fit model (params=None), one per
# input trace.
fits = fitter.generate(output_var='v', params=None)

In [None]:
# fig, ax = plt.subplots(nrows=2, figsize=(15,10))

# ax[0].plot(voltage1);
# ax[0].plot(fits[0]/mV)

# ax[1].plot(voltage2);
# ax[1].plot(fits[1]/mV);


In [None]:
# Time axis of the recordings. It is built from integer sample indices so that
# len(t0) == len(voltage1) holds exactly; np.arange with a fractional step
# (0.01 ms) can return one extra element because of floating-point rounding.
t0 = np.arange(len(voltage1)) * dt

# Model spike times of trace 0 below 3000. The units are presumably ms
# (TODO confirm the units returned by generate_spikes).
t = spikes[0][spikes[0] < 3000]

In [None]:
# Inspect the selected model spike times of trace 0.
t

In [None]:
# Fitted voltage trace of recording 0, and the end sample of the plotting window.
v, d = fits[0], 300000

In [None]:
# Four stacked panels for recording 0 over samples [1000, d):
#   0: injected current
#   1: recorded voltage
#   2: fitted model voltage (shifted by -60 mV) with spike markers
#   3: recorded and fitted voltage overlaid
fig, ax = plt.subplots(nrows=4, figsize=(14, 10))

window = slice(1000, d)
t_window_ms = t0[window] / ms
# Model voltage at each spike. The sample index t / dt * ms assumes `t` holds
# unitless ms values (TODO confirm).
spike_v_mV = v[np.int_(np.round(t / dt * ms))] / mV

ax[0].plot(t_window_ms, inp_cur1[window]);
ax[1].plot(t_window_ms, voltage1[window]);
ax[2].plot(t_window_ms, v[window] / mV - 60, 'g')
ax[2].vlines(t, spike_v_mV + 10, spike_v_mV - 60, 'g');

ax[3].plot(t_window_ms, voltage1[window]);
ax[3].plot(t_window_ms, v[window] / mV - 60, 'r')
ax[3].vlines(t, spike_v_mV + 10, spike_v_mV - 60, 'r');



In [None]:
# Overlay of the fitted model voltage (red, shifted by -60 mV) on the recorded
# voltage of trace 0. Each model spike is drawn as a vertical line from v-60
# to v+60 mV.
spike_v_mV = v[np.int_(np.round(t / dt * ms))] / mV
plt.figure(figsize=(10, 5))
plt.plot(t0[1000:d] / ms, v[1000:d] / mV - 60, 'r')
plt.vlines(t, spike_v_mV + 60, spike_v_mV - 60, 'r');
plt.plot(t0[1000:d] / ms, voltage1[1000:d]);


In [None]:
# A fixed parameter set, presumably from an earlier fit. The values appear to
# be in SI units (ohm, volt, second); TODO confirm. There is no entry for the
# refractory period D.
param = dict(
    R=38617749.58677548,
    alpha=0.006896191861644845,
    taut=0.0803213106795783,
    tau=0.013730585169469543,
    a=1.5549270966961934,
)
param