In [57]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d
from scipy.stats import zscore
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

In [204]:
# Per-cell-type parameter sets; each dict carries the four coefficients
# a, b, c, d that run_l1() reads for the recovery ODE and the spike reset.
rs  = dict(a=0.02, b=0.20, c=-65.0, d=8.00)
ib  = dict(a=0.02, b=0.20, c=-55.0, d=4.00)
ch  = dict(a=0.02, b=0.20, c=-50.0, d=2.00)
fs  = dict(a=0.10, b=0.20, c=-65.0, d=2.00)
th  = dict(a=0.02, b=0.25, c=-65.0, d=0.05)
res = dict(a=0.10, b=0.25, c=-65.0, d=2.00)
lts = dict(a=0.02, b=0.25, c=-65.0, d=2.00)

# Label -> parameter set. Insertion order fixes the column order of V/u/spikes.
# Note the 'in' label maps to the `ib` parameter set.
nrns = dict(zip(
    ('rs', 'in', 'ch', 'fs', 'th', 'res', 'lts'),
    (rs, ib, ch, fs, th, res, lts),
))

N = len(nrns)   # number of simulated neurons (one per entry in nrns)
T = 5000        # length in samples
dt = 0.1        # integration step used by run_l1()

V_0 = -70       # initial membrane potential
u_0 = -14       # initial recovery variable
V_spike = 35    # threshold at which run_l1() registers a spike and resets

# State traces, shape (T, N); only row 0 is populated here, the rest stays zero
# until the simulation fills it in.
V = np.zeros((T, N))
u = np.zeros((T, N))
V[0, :] = V_0
u[0, :] = u_0

# Column-aligned parameter lists: element i belongs to column i of V/u.
a = [params['a'] for params in nrns.values()]
b = [params['b'] for params in nrns.values()]
c = [params['c'] for params in nrns.values()]
d = [params['d'] for params in nrns.values()]

spikes = np.zeros((T, N))   # 1 where a spike was registered, else 0
stim = np.zeros(T)          # stimulus trace, written by run_l1()

In [202]:
def run_l1(seed=None):
    """Step every neuron column of the notebook-level state forward in time.

    Operates entirely on module-level names: reads ``T``, ``dt``, ``V_spike``
    and the per-neuron coefficient lists ``a``, ``b``, ``c``, ``d``; writes the
    arrays ``V``, ``u``, ``spikes`` and ``stim`` in place. Row 0 of ``V`` and
    ``u`` is taken as the initial condition; rows 1..T-1 are overwritten.

    Each step uses a forward-Euler update of the quadratic membrane equation
    ``0.04*V**2 + 5*V + 140 - u`` (Izhikevich-style) plus a sinusoidal drive
    whose frequency is jittered per neuron and per step.

    Parameters
    ----------
    seed : int or None, optional
        When None (default) the jitter is drawn from the global ``np.random``
        state, exactly as before. When given, a private
        ``np.random.RandomState(seed)`` is used so the run is reproducible.

    Returns
    -------
    None
    """
    rng = np.random if seed is None else np.random.RandomState(seed)
    # Size everything from the state array instead of a hard-coded 7, so the
    # function keeps working when neurons are added to or removed from `nrns`.
    n_cells = V.shape[1]

    # V, u and stim are fully rewritten below, but spikes is only ever set to
    # 1; clear it so marks from a previous run do not leak into this one.
    spikes[...] = 0

    for t in range(1, T):
        # Frequency jitter: randint(8, 10) yields 8 or 9, i.e. 0.8 or 0.9.
        noise = rng.randint(8, 10, n_cells) / 10
        for i in range(n_cells):
            # NOTE: stim is 1-D, so after the inner loop stim[t] holds the
            # drive of the last neuron only; each neuron still uses its own.
            stim[t] = 1 + np.sin(t * dt**2 * noise[i])
            if V[t-1, i] < V_spike:
                # Spike threshold not reached yet: integrate both ODEs.
                # Membrane potential, with a gain of 3 on the stimulus.
                dV = (0.04 * V[t-1, i] + 5) * V[t-1, i] + 140 - u[t-1, i]
                V[t, i] = V[t-1, i] + (dV + 3 * stim[t]) * dt
                # Recovery variable.
                du = a[i] * (b[i] * V[t-1, i] - u[t-1, i])
                u[t, i] = u[t-1, i] + dt * du
            else:
                # Spike reached on the previous sample.
                V[t-1, i] = V_spike          # clip the overshoot to the spike value
                V[t, i] = c[i]               # reset membrane voltage
                u[t, i] = u[t-1, i] + d[i]   # bump recovery variable
                spikes[t-1, i] = 1

In [203]:
run_l1()
%matplotlib qt
fig,ax = plt.subplot_mosaic('AB')
ax['A'].plot(V,alpha=0.5)
ax['B'].scatter(*np.where(spikes==1),color='black')
ax['B'].plot(stim,alpha=0.5)

[<matplotlib.lines.Line2D at 0x2c5acc78f10>]