In [2]:
import numpy as np
import matplotlib.pyplot as plt
import keras.backend as K

# Define the Izhikevich neuron model


def izhikevich(v, u, a, b, c, d, dt, I):
    """Advance the Izhikevich neuron model by one forward-Euler step.

    Works on Python scalars and on NumPy arrays (element-wise, with normal
    broadcasting), using plain NumPy so no deep-learning backend is needed
    for the threshold test.

    Parameters
    ----------
    v : float or array_like
        Membrane potential before the step.
    u : float or array_like
        Recovery variable before the step.
    a, b : float
        Recovery dynamics parameters: ``du/dt = a * (b * v - u)``.
    c : float
        Value the membrane potential is reset to after a spike.
    d : float
        Increment added to the recovery variable after a spike.
    dt : float
        Integration step size.
    I : float or array_like
        Input current added to ``dv/dt``.

    Returns
    -------
    v_new, u_new, spike
        Updated membrane potential, updated recovery variable, and a
        float32 indicator that is 1.0 where the neuron fired on this step
        and 0.0 elsewhere. For scalar inputs these are 0-d NumPy arrays,
        which can be assigned into array elements or passed to ``float``.
    """
    # Tentative Euler update of both state variables from the *old* state.
    v_next = v + dt * (0.04 * v**2 + 5 * v + 140 - u + I)
    u_next = u + dt * a * (b * v - u)

    # A spike is emitted when the updated potential reaches the 30 mV cutoff.
    fired = np.asarray(v_next >= 30.0)
    spike = fired.astype(np.float32)

    # On a spike, reset v to c and bump the pre-step u by d; otherwise keep
    # the Euler-updated values.
    v_new = np.where(fired, c, v_next)
    u_new = np.where(fired, u + d, u_next)
    return v_new, u_new, spike


# Simulate a single Izhikevich neuron
# Model parameters
a = 0.02  # recovery rate: scales du/dt, smaller values mean slower recovery
b = 0.2  # sensitivity of recovery to subthreshold fluctuations
c = -65.0  # reset value of membrane potential after spike (mV)
d = 8.0  # increment applied to the recovery variable after a spike

# Initial state
v = -65.0  # initial membrane potential (mV)
u = b * v  # initial recovery variable, derived from b so the two stay in sync

# Integration settings and input
dt = 0.1  # time step size (ms)
I = 50.0  # input current amplitude
t_max = 1000.0  # maximum simulation time (ms)

# Round before truncating: t_max / dt can land just below an integer in
# floating point, and a bare int() would then drop the last step.
n_steps = int(round(t_max / dt))
t = np.arange(n_steps) * dt  # time stamp of every step (ms)
spikes = np.zeros(n_steps)  # filled with the per-step spike indicator


In [3]:
for i in range(n_steps):
    v, u, spike = izhikevich(v, u, a, b, c, d, dt, I)
    spikes[i] = spike



In [4]:
print(spikes)

[0. 0. 0. ... 0. 0. 0.]


In [None]:

# Plot the membrane potential and spike train
fig, ax = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=(8, 6))
ax[0].plot(t, v, label='membrane potential')
ax[0].set_ylabel('Voltage (mV)')
ax[0].legend()
ax[1].plot(t, spikes, label='spike train')
ax[1].set_xlabel('Time (ms)')
ax[1].set_ylabel('Spikes')
ax[1].legend()
plt.show()
