In [3]:
import brainpy as bp
import brainpy.math as bm
import numpy as np
import matplotlib.pyplot as plt
from brainpy.dyn import LifRef
import common_functions as cf
from scipy.sparse import csr_matrix, coo_matrix
import brainpy_functions as bf

# Compute backend. For a GPU run, presumably select a device first
# (cf.set_gpu('0') or cf.set_least_used_gpu()) and use bm.set_platform('gpu').
bm.set_platform('cpu')
print(bp.__version__)

2.5.0


In [4]:
# --- Population sizes --------------------------------------------------------
EI_ratio = 4                    # excitatory neurons per inhibitory neuron
E_size = 4 * 400
I_size = E_size // EI_ratio

# --- Synaptic weights (E-origin positive, I-origin negative) ------------------
E2E_weight = 1
E2I_weight = 1
I2E_weight = -4
I2I_weight = -4

# --- Neuron parameters (passed to bf.SpatialEINet as E_params / I_params) ----
# Both populations currently use identical values; only 'size' differs.
E_params = dict(size=E_size, V_th=20.0, V_reset=-5.0, V_rest=0., tau_ref=5.0, R=1.0, tau=10.0)
I_params = dict(size=I_size, V_th=20.0, V_reset=-5.0, V_rest=0., tau_ref=5.0, R=1.0, tau=10.0)

# --- Per-projection synapse parameters (delay of 0 for every projection) -----
E2E_synapse_params = dict(delay=0)
E2I_synapse_params = dict(delay=0)
I2E_synapse_params = dict(delay=0)
I2I_synapse_params = dict(delay=0)

# --- Constant external drive used while the input is switched on -------------
E_inp = 40
I_inp = 30

def grid_positions(n, offset=0.):
    """Return an (n, 2) array of points on a regular sqrt(n) x sqrt(n) grid over [0, 1]^2.

    Parameters
    ----------
    n : int
        Number of points; must be a perfect square.
    offset : float
        Constant added to every coordinate.

    Raises
    ------
    ValueError
        If ``n`` is not a perfect square. Truncating sqrt(n) instead would
        silently return fewer positions than there are neurons.
    """
    side = int(round(np.sqrt(n)))
    if side * side != n:
        raise ValueError(f'n={n} is not a perfect square; cannot place neurons on a square grid')
    axis = np.linspace(0, 1, side)
    xx, yy = np.meshgrid(axis, axis)
    return np.stack([xx.flatten(), yy.flatten()], axis=1) + offset


# The E grid is shifted by 0.1 relative to the I grid.
# NOTE(review): presumably so E and I sites do not coincide -- confirm.
E_pos = grid_positions(E_size, offset=0.1)
I_pos = grid_positions(I_size)

# Global integration time step for brainpy (also used for ts_* and the analysis below).
dt = 1.
bm.set_dt(dt)

def zero_one_csr(row_indices, col_indices, shape):
    """Build a binary (0/1) CSR adjacency matrix from coordinate lists.

    Parameters
    ----------
    row_indices, col_indices : array-like of int
        Row / column coordinates of the connections (equal length).
    shape : tuple of int
        Matrix shape ``(n_rows, n_cols)``.

    Returns
    -------
    scipy.sparse.csr_matrix
        Matrix with a 1 at every listed coordinate. Repeated (row, col)
        pairs yield a single 1 rather than their count.
    """
    mat = csr_matrix((np.ones_like(row_indices), (row_indices, col_indices)), shape=shape)
    # Converting COO coordinates to CSR sums duplicate entries (e.g. repeated
    # random draws), which would give values > 1. Merge the duplicates, then
    # force every stored value back to 1.
    mat.sum_duplicates()
    mat.data[:] = 1
    return mat

def zero_one_conn(row_indices, col_indices, shape):
    """Wrap a 0/1 adjacency (built by zero_one_csr) in a brainpy SparseMatConn connector.

    Arguments are forwarded unchanged to zero_one_csr.
    """
    adjacency = zero_one_csr(row_indices, col_indices, shape)
    return bp.connect.SparseMatConn(csr_mat=adjacency)

# Connectivity for the four projections.
#   use_gaussian_conn = True  -> bp.connect.GaussianProb connectors with width conn_sigma
#                                (the configuration this notebook runs with);
#   use_gaussian_conn = False -> conn_num uniformly random 0/1 synapses per projection
#                                built with zero_one_conn.
# Only the selected connector is built, so no work or np.random draws are wasted
# on a connector that is thrown away.
conn_num = 100
conn_sigma = 0.1
use_gaussian_conn = True


def make_conn(pre_num, post_num):
    """Return a connector from a population of `pre_num` neurons to one of `post_num` neurons."""
    if use_gaussian_conn:
        return bp.connect.GaussianProb(sigma=conn_sigma, pre=pre_num, post=post_num)
    return zero_one_conn(np.random.randint(0, pre_num, conn_num),
                         np.random.randint(0, post_num, conn_num),
                         (pre_num, post_num))


E2E_conn = make_conn(E_size, E_size)
E2E_comm = bp.dnn.EventCSRLinear(conn=E2E_conn, weight=E2E_weight)

E2I_conn = make_conn(E_size, I_size)
E2I_comm = bp.dnn.EventCSRLinear(conn=E2I_conn, weight=E2I_weight)

I2E_conn = make_conn(I_size, E_size)
I2E_comm = bp.dnn.EventCSRLinear(conn=I2E_conn, weight=I2E_weight)

I2I_conn = make_conn(I_size, I_size)
I2I_comm = bp.dnn.EventCSRLinear(conn=I2I_conn, weight=I2I_weight)

# Spatial E/I network from the project helper module `bf`: LifRef neurons in both
# populations, a FullProjDelta synapse on every projection, and the comm layers
# and grid positions built above.
EI_net = bf.SpatialEINet(
    E_neuron=bp.dyn.LifRef,
    I_neuron=bp.dyn.LifRef,
    E_params=E_params,
    I_params=I_params,
    E2E_synapse=bp.dyn.FullProjDelta,
    E2I_synapse=bp.dyn.FullProjDelta,
    I2E_synapse=bp.dyn.FullProjDelta,
    I2I_synapse=bp.dyn.FullProjDelta,
    E2E_synapse_params=E2E_synapse_params,
    E2I_synapse_params=E2I_synapse_params,
    I2E_synapse_params=I2E_synapse_params,
    I2I_synapse_params=I2I_synapse_params,
    E2E_comm=E2E_comm,
    E2I_comm=E2I_comm,
    I2E_comm=I2E_comm,
    I2I_comm=I2I_comm,
    E_pos=E_pos,
    I_pos=I_pos,
)


def run_fun_1(i):
    """Advance EI_net by one step with constant external drive.

    Every E neuron receives E_inp and every I neuron receives I_inp (as float
    arrays). Returns the result of EI_net.step_run, which callers unpack as
    (E spikes, I spikes, E V, I V).
    """
    return EI_net.step_run(
        i,
        np.full(E_size, E_inp, dtype=float),
        np.full(I_size, I_inp, dtype=float),
    )

def run_fun_2(i):
    """Advance EI_net by one step with the external drive switched off (scalar 0 to both populations)."""
    no_drive = 0
    return EI_net.step_run(i, no_drive, no_drive)


def run_segment(run_fun, start, stop):
    """Run `run_fun` over step indices [start, stop) with bm.for_loop.

    Prints the number of steps, then returns
    ``(indices, ts, E_spikes, I_spikes, E_V, I_V)`` where ``ts = indices * bm.get_dt()``.
    """
    indices = np.arange(start, stop)
    ts = indices * bm.get_dt()
    print(len(indices))
    E_spikes, I_spikes, E_V, I_V = bm.for_loop(run_fun, indices, progress_bar=True)
    return indices, ts, E_spikes, I_spikes, E_V, I_V


# Step indices continue across phases rather than restarting at 0
# (presumably so the time index seen by step_run keeps advancing).
# Phase 1: external drive on, steps 0-99.
indices_1, ts_1, E_spikes_1, I_spikes_1, E_V_1, I_V_1 = run_segment(run_fun_1, 0, 100)
# Phase 2: external drive off, steps 100-149.
indices_2, ts_2, E_spikes_2, I_spikes_2, E_V_2, I_V_2 = run_segment(run_fun_2, 100, 150)
# Phase 3: external drive on again, steps 150-199.
indices_3, ts_3, E_spikes_3, I_spikes_3, E_V_3, I_V_3 = run_segment(run_fun_1, 150, 200)


# Phase 1: membrane potential of column 0 of the E and I recordings, with the
# threshold and reset levels and the E2I / I2I weights drawn as reference lines.
# NOTE(review): the weight lines presumably show the voltage jump of one
# incoming spike relative to V_th -- confirm.
fig, ax = cf.create_fig_ax()
cf.plt_line(ax, ts_1, E_V_1[:, 0], label='E', color=cf.BLUE)
cf.plt_line(ax, ts_1, I_V_1[:, 0], label='I', color=cf.ORANGE)
cf.add_hline(ax, I_params['V_th'], label='Threshold', color=cf.RED)
cf.add_hline(ax, I_params['V_reset'], label='Reset', color=cf.GREEN)
cf.add_hline(ax, E2I_weight, label='E2I_weight', color=cf.BLACK)
cf.add_hline(ax, I2I_weight, label='I2I_weight', color=cf.PURPLE)
cf.set_ax(ax, 'Time (ms)', 'Membrane potential (mV)')

# Column 0 of the E recordings across all three phases (drive on / off / on).
fig, ax = cf.create_fig_ax()
cf.plt_line(ax, np.concatenate([ts_1, ts_2, ts_3]), np.concatenate([E_V_1[:, 0], E_V_2[:, 0], E_V_3[:, 0]]), label='E', color=cf.BLUE)
cf.set_ax(ax, 'Time (ms)', 'Membrane potential (mV)')

# Render phase-1 spikes of both populations at their grid positions into ./spatial_EI_net/.
bf.spike_video(E_spikes_1, E_pos, I_spikes_1, I_pos, dt, './spatial_EI_net/')

# Phase-1 firing rate of E neurons 0-9 (neuron_idx=slice(0, 10)).
# NOTE(review): the meaning of the two `dt` arguments is defined in bf --
# presumably the simulation step and the rate bin width; confirm.
fr = bf.spike_to_fr(E_spikes_1, dt, dt, neuron_idx=slice(0, 10))

fig, ax = cf.create_fig_ax()
cf.plt_line(ax, ts_1, fr, label='Firing rate', color=cf.BLUE)
cf.set_ax(ax, 'Time (ms)', 'Firing rate')



# Autocorrelation of phase-1 E activity, computed from the spike trains and from
# the binned firing rate. NOTE(review): the trailing 10 is presumably the maximum
# lag -- confirm in bf.
spike_lag_times, spike_acf = bf.get_spike_acf(E_spikes_1, dt, 10)
fr_lag_times, fr_acf = bf.spike_to_fr_acf(E_spikes_1, dt, dt, 10)

fig, ax = cf.create_fig_ax()
cf.plt_stem(ax, spike_lag_times, spike_acf, label='Spike ACF')
cf.set_ax(ax, 'Lag', 'Autocorrelation')

fig, ax = cf.create_fig_ax()
cf.plt_stem(ax, fr_lag_times, fr_acf, label='FR ACF')
cf.set_ax(ax, 'Lag', 'Autocorrelation')

100


  1%|          | 1/100 [00:00<00:52,  1.89it/s]

: 

In [None]:


# # set parameters

# num_inh = 1
# num_exc = 1
# prob = 0.25

# tau_E = 15.
# tau_I = 10.
# V_reset = -1.
# V_threshold = 15.
# V_rest = 0.
# f_E = 3.
# f_I = 2.
# mu_f = 6.

# tau_Es = 6.
# tau_Is = 5.
# JEE = 0.25
# JEI = -1.
# JIE = 0.4
# JII = -1.
# class ExponCUBA(bp.Projection):
#     def __init__(self, pre, post, prob, g_max, tau):
#         super().__init__()
#         self.proj = bp.dyn.ProjAlignPostMg2(
#             pre=pre,
#             delay=None,
#             comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max),
#             syn=bp.dyn.Expon.desc(post.num, tau=tau),
#             out=bp.dyn.CUBA.desc(),
#             post=post,
#         )


# class EINet(bp.DynSysGroup):
#     def __init__(self):
#         super().__init__()
#         # neurons
#         self.E = LifRef(num_exc, tau=tau_E, tau_ref=10, V_rest=V_rest, V_reset=V_reset, V_th=V_threshold)
#         self.I = LifRef(num_inh, tau=tau_I, tau_ref=10, V_rest=V_rest, V_reset=V_reset, V_th=V_threshold)
#         self.E.V[:] = 0.
#         self.I.V[:] = 0.

#         # synapses
#         # E2E_csr = csr_matrix((cf.repeat_data([1], num_exc), (np.arange(num_exc), np.arange(num_exc))), shape=(num_exc, num_exc))
#         # E2E_conn = bp.connect.SparseMatConn(csr_mat=E2E_csr)
#         # self.E2E = bp.dyn.FullProjDelta(self.E, 0., bp.dnn.EventCSRLinear(conn=E2E_conn, weight=1), self.E)

#         E2I_csr = csr_matrix((cf.repeat_data([1], num_inh), (np.arange(num_inh), np.arange(num_inh))), shape=(num_exc, num_inh))
#         E2I_conn = bp.connect.SparseMatConn(csr_mat=E2I_csr)
#         self.E2I = bp.dyn.FullProjDelta(self.E, 0., bp.dnn.EventCSRLinear(conn=E2I_conn, weight=2), self.I)

#         # I2E_csr = csr_matrix((cf.repeat_data([1], num_inh), (np.arange(num_inh), np.arange(num_inh))), shape=(num_inh, num_exc))
#         # I2E_conn = bp.connect.SparseMatConn(csr_mat=I2E_csr)
#         # self.I2E = bp.dyn.FullProjDelta(self.I, 0., bp.dnn.EventCSRLinear(conn=I2E_conn, weight=1), self.E)

#         # I2I_csr = csr_matrix((cf.repeat_data([1], num_inh), (np.arange(num_inh), np.arange(num_inh))), shape=(num_inh, num_inh))
#         # I2I_conn = bp.connect.SparseMatConn(csr_mat=I2I_csr)
#         # self.I2I = bp.dyn.FullProjDelta(self.I, 0., bp.dnn.EventCSRLinear(conn=I2I_conn, weight=1), self.I)

#         # self.E2I = ExponCUBA(self.E, self.I, prob, tau=tau_Es, g_max=JIE)
#         # self.E2E = ExponCUBA(self.E, self.E, prob, tau=tau_Es, g_max=JEE)
#         # self.I2I = ExponCUBA(self.I, self.I, prob, tau=tau_Is, g_max=JII)
#         # self.I2E = ExponCUBA(self.I, self.E, prob, tau=tau_Is, g_max=JEI)

#     def update(self, e_inp, i_inp):
#         # self.E2E()
#         self.E2I()
#         # self.I2E()
#         # self.I2I()
#         self.E(e_inp)
#         self.I(i_inp)

#         # monitor
#         return self.E.spike, self.I.spike, self.E.V, self.I.V


# net = EINet()


# def run_fun(i):
#     e_inp = f_E * mu_f
#     # i_inp = f_I * mu_f
#     i_inp = 0
#     return net.step_run(i, e_inp, i_inp)


# indices_1 = np.arange(1000)  # 100. ms
# print(len(indices_1))
# e_sps_1, i_sps_1, e_v_1, i_v_1 = bm.for_loop(
#     run_fun, indices_1, progress_bar=True)
# ts_1 = indices_1 * bm.get_dt()

# # Run once more (steps 1000-1999)
# indices_2 = np.arange(1000, 2000)  # 100. ms
# print(len(indices_2))
# e_sps_2, i_sps_2, e_v_2, i_v_2 = bm.for_loop(
#     run_fun, indices_2, progress_bar=True)
# ts_2 = indices_2 * bm.get_dt()

# fig, ax = cf.create_fig_ax()
# cf.plt_line_plot(ax, ts_1, e_v_1[:, 0], label='E', color=cf.DEFAULT_BLUE)
# cf.plt_line_plot(ax, ts_1, i_v_1[:, 0], label='I', color=cf.DEFAULT_ORANGE)
# cf.add_hline(ax, V_threshold, label='Threshold', color=cf.DEFAULT_RED)
# cf.add_hline(ax, V_reset, label='Reset', color=cf.DEFAULT_GREEN)
# cf.add_hline(ax, 1, label='one spike weight', color=cf.DEFAULT_BLACK)
# cf.set_ax(ax, 'Time (ms)', 'Membrane potential (mV)')