<a href="https://colab.research.google.com/github/cahcharm/Neuro-and-Complexity-Science/blob/CSHA2021/project/project_spiking_circuit_model_for_working_memory.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Spiking circuit model for working memory**

Building a spiking circuit model for working memory, following the synaptic theory of working memory (Mongillo, Barak & Tsodyks, *Science*, 2008).

* Step 1: build LIF neuron, voltage jump synapse, STP synapse
*	Step 2: consider the connectivity used in the paper, build up the network model
*	Step 3: try to encode one item in the model (Fig. 2)
*	Step 4: try to encode two items in the network (Fig. 3)


In [None]:
pip install brainpy-simulator

Collecting brainpy-simulator
  Downloading brainpy-simulator-1.0.2.tar.gz (126 kB)
[?25l[K     |██▋                             | 10 kB 20.5 MB/s eta 0:00:01[K     |█████▏                          | 20 kB 22.5 MB/s eta 0:00:01[K     |███████▉                        | 30 kB 25.5 MB/s eta 0:00:01[K     |██████████▍                     | 40 kB 24.0 MB/s eta 0:00:01[K     |█████████████                   | 51 kB 22.9 MB/s eta 0:00:01[K     |███████████████▋                | 61 kB 24.8 MB/s eta 0:00:01[K     |██████████████████▏             | 71 kB 20.0 MB/s eta 0:00:01[K     |████████████████████▊           | 81 kB 21.4 MB/s eta 0:00:01[K     |███████████████████████▍        | 92 kB 22.3 MB/s eta 0:00:01[K     |██████████████████████████      | 102 kB 22.9 MB/s eta 0:00:01[K     |████████████████████████████▋   | 112 kB 22.9 MB/s eta 0:00:01[K     |███████████████████████████████▏| 122 kB 22.9 MB/s eta 0:00:01[K     |████████████████████████████████| 126 kB 22.9 

In [None]:
import numpy as np
import brainpy as bp

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

## Step 1: build LIF neuron, voltage jump synapse, STP synapse

**LIF Neuron**

In [None]:
class LIF(bp.NeuGroup):
  """Leaky integrate-and-fire neuron group with an absolute refractory period.

  Membrane equation: tau * dV/dt = -V + V_rest + R * I_ext.
  When V reaches V_th the neuron spikes, V is set to V_reset, and V is
  frozen for t_ref time units. ``input`` is cleared at the end of every
  step, so external currents and synapses must write into it each step.
  """
  target_backend = ['numpy', 'numba']

  @staticmethod
  @bp.odeint(method='exponential_euler')
  def integral(V, t, Iext, V_rest, R, tau):
    dvdt = (-V + V_rest + R * Iext) / tau
    return dvdt

  def __init__(self, size, t_ref=1., V_rest=0., V_reset=0., 
               V_th=20., R=1., tau=10., **kwargs):
    super(LIF, self).__init__(size=size, **kwargs)

    # model parameters
    self.tau = tau
    self.R = R
    self.V_rest, self.V_reset, self.V_th = V_rest, V_reset, V_th
    self.t_ref = t_ref

    # per-neuron state arrays, each of length self.num
    self.V = V_rest * bp.ops.ones(self.num)
    self.input = bp.ops.zeros(self.num)
    self.spike = bp.ops.zeros(self.num, dtype=bool)
    self.refractory = bp.ops.zeros(self.num, dtype=bool)
    # last spike time starts far in the past so no neuron begins refractory
    self.t_last_spike = -1e7 * bp.ops.ones(self.num)

  def update(self, _t):
    for idx in range(self.num):
      fired = False
      in_refractory = (_t - self.t_last_spike[idx]) <= self.t_ref
      if not in_refractory:
        V_new = self.integral(self.V[idx], _t, self.input[idx],
                              self.V_rest, self.R, self.tau)
        fired = (V_new >= self.V_th)
        if fired:
          # spike: reset the membrane and start a new refractory period
          V_new = self.V_reset
          self.t_last_spike[idx] = _t
          in_refractory = True
        self.V[idx] = V_new
      self.spike[idx] = fired
      self.refractory[idx] = in_refractory
      # the input is consumed once per step
      self.input[idx] = 0.

**Voltage Jump Synapse**

$$
I = \sum_{j\in C} g \delta(t-t_j-D)
$$

where $g$ denotes the chemical synaptic strength (the size of the voltage jump), $t_j$ the spike time of presynaptic neuron $j$, $C$ the set of presynaptic neurons connected to the target (e.g. the neurons of an encoding layer), and $D$ the transmission delay of chemical synapses. For simplicity, we omit the rise and decay phases of postsynaptic currents.

In [None]:
class VoltageJump(bp.TwoEndConn):
    """Delta-current synapse: each presynaptic spike makes the target's voltage jump.

    A spike of presynaptic neuron ``pre_ids[i]`` adds ``weight`` directly to
    ``V`` of its own postsynaptic neuron ``post_ids[i]`` after a constant
    transmission delay ``delay``. If ``post_refractory`` is True, the jump is
    suppressed while that postsynaptic neuron is refractory.

    Parameters
    ----------
    pre, post : neuron groups
        ``pre`` must expose ``spike``; ``post`` must expose ``V`` and
        ``input`` (and ``refractory`` when ``post_refractory`` is True).
    conn : connector
        BrainPy connector providing ``pre_ids`` / ``post_ids``.
    delay : float
        Transmission delay.
    post_refractory : bool
        Whether to block jumps into refractory postsynaptic neurons.
    weight : float
        Size of the voltage jump per spike.
    """
    target_backend = ['numpy', 'numba']

    def __init__(self, pre, post, conn, delay=0., post_refractory=False, weight=1., **kwargs):
        # checking (update() reads pre.spike, post.V and optionally post.refractory)
        assert hasattr(pre, 'spike'), 'Pre-synaptic group must have "spike" variable.'
        assert hasattr(post, 'V'), 'Post-synaptic group must have "V" variable.'
        assert hasattr(post, 'input'), 'Post-synaptic group must have "input" variable.'
        if post_refractory:
            assert hasattr(post, 'refractory'), 'Post-synaptic group must have "refractory" variable.'

        # parameters
        self.delay = delay
        self.post_has_refractory = post_refractory

        # connections
        self.conn = conn(pre.size, post.size)
        self.pre_ids, self.post_ids = conn.requires('pre_ids', 'post_ids')
        self.num = len(self.pre_ids)

        # variables (bp.ops for consistency with the other models in this notebook)
        self.s = bp.ops.zeros(self.num)
        self.w = bp.ops.ones(self.num) * weight
        self.I_syn = self.register_constant_delay('I_syn', size=self.num, delay_time=delay)

        super(VoltageJump, self).__init__(pre=pre, post=post, **kwargs)

    def update(self, _t):
        for i in range(self.num):
            pre_id = self.pre_ids[i]
            post_id = self.post_ids[i]
            # update: record this step's presynaptic spike and enqueue its jump
            self.s[i] = self.pre.spike[pre_id]
            self.I_syn.push(i, self.s[i] * self.w[i])
            # output: apply the jump enqueued `delay` ago, to this synapse's
            # own postsynaptic neuron only (not to the whole group)
            I_syn = self.I_syn.pull(i)
            if self.post_has_refractory:
                self.post.V[post_id] += I_syn * (1. - self.post.refractory[post_id])
            else:
                self.post.V[post_id] += I_syn

In [None]:
# Demo: a tonically driven LIF neuron (input 25) excites a second LIF neuron
# (input 10, below threshold) through a voltage-jump synapse with delay 2.0.
neu1 = LIF(1, monitors=['V', 'spike'])
neu2 = LIF(1, monitors=['V'])

syn1 = VoltageJump(pre=neu1, post=neu2, conn=bp.connect.All2All(), delay=2.0)

net = bp.Network(neu1, syn1, neu2)
net.run(150., inputs=[(neu1, 'input', 25.), (neu2, 'input', 10.)])

fig, gs = bp.visualize.get_figure(1, 1, 3, 8)
fig.add_subplot(gs[0, 0])
plt.plot(net.ts, neu1.mon.V, label='pre-V')
plt.plot(net.ts, neu2.mon.V, label='post-V')
plt.xlim(40, 150)
plt.xlabel('Time (ms)')
plt.ylabel('V')
plt.legend()
plt.show()


**STP Synapse**

$$\frac{d u}{d t}= -\frac{u}{\tau_{f}}+U\left(1-u^{-}\right) \delta\left(t-t_{s p}\right)$$

$$\frac{d x}{d t}= \frac{1-x}{\tau_{d}}-u^{+} x^{-} \delta\left(t-t_{s p}\right)$$

$$\frac{d I}{d t}= -\frac{I}{\tau_{s}}+A u^{+} x^{-} \delta\left(t-t_{s p}\right)$$

$$u^{+}=u^{-}+U\left(1-u^{-}\right)$$

Or we can see the dynamics as:

$$
\frac {du} {dt} = - \frac u {\tau_f} 
$$

$$
\frac {dx} {dt} =  \frac {1-x} {\tau_d} 
$$

$$
\frac {dI} {dt} = - \frac I {\tau}
$$

$$
\rm{if (pre \ fire), then}
\begin{cases} 
u^+ = u^- + U(1-u^-) \\ 
I^+ = I^- + Au^+x^- \\
x^+ = x^- - u^+x^- 
\end{cases}
$$

In [None]:
class STP(bp.TwoEndConn):
  """Exponentially decaying synapse with short-term plasticity (STP).

  Between presynaptic spikes:
      ds/dt = -s / tau,   du/dt = -u / tau_f,   dx/dt = (1 - x) / tau_d.
  On every presynaptic spike (u-, x- are the values just before the spike):
      u+ = u- + U (1 - u-),   s+ = s- + A u+ x-,   x+ = x- - u+ x-.
  ``s`` is passed through a constant delay and added to the postsynaptic
  ``input``.

  Parameters
  ----------
  pre, post : neuron groups
      ``pre`` must expose ``spike``; ``post`` must expose ``input``.
  conn : connector
      BrainPy connector providing ``pre_ids`` / ``post_ids``.
  delay : float
      Transmission delay of ``s``.
  U : float
      Per-spike increment of ``u`` (scaled by ``1 - u-``).
  tau_f, tau_d, tau : float
      Time constants of ``u`` (facilitation), ``x`` (depression) and ``s``.
  A : float
      Scale of the per-spike jump of ``s``.
  """
  target_backend = 'general'

  @staticmethod
  @bp.odeint(method='exponential_euler')
  def integral(s, u, x, t, tau, tau_d, tau_f):
    # Dynamics
    dsdt = - s / tau 
    dudt = - u / tau_f 
    dxdt = (1 - x) / tau_d
    return dsdt, dudt, dxdt
  
  def __init__(self, pre, post, conn, delay=0., U=0.15, tau_f=1500., tau_d=200., tau=8., A=1.,  **kwargs):
    # parameters
    self.tau_d = tau_d
    self.tau_f = tau_f
    self.tau = tau
    self.U = U
    self.A = A
    self.delay = delay

    # connections
    self.conn = conn(pre.size, post.size)
    self.pre_ids, self.post_ids = conn.requires('pre_ids', 'post_ids')
    self.num = len(self.pre_ids)

    # variables
    self.s = bp.ops.zeros(self.num)
    self.x = bp.ops.ones(self.num)
    self.u = bp.ops.zeros(self.num)
    self.I_syn = self.register_constant_delay('I_syn', size=self.num, delay_time=delay)
    
    super(STP, self).__init__(pre=pre, post=post, **kwargs)

  def update(self, _t):
    for i in range(self.num):
      pre_id, post_id = self.pre_ids[i], self.post_ids[i]

      # continuous relaxation over one step; u and x are now u-, x-
      self.s[i], u, x = self.integral(self.s[i], self.u[i], self.x[i], _t, self.tau, self.tau_d, self.tau_f)
      if self.pre.spike[pre_id]:
        # spike-triggered jumps use the just-relaxed values u-, x- so that
        # every term refers to the same time point (see the update rules)
        u += self.U * (1 - u)
        self.s[i] += self.A * u * x
        x -= u * x
      self.u[i] = u
      self.x[i] = x

      # output
      self.I_syn.push(i, self.s[i])
      self.post.input[post_id] += self.I_syn.pull(i)

In [None]:
## STD/STF parameters and plot
# Depression-dominated setting: slow recovery of x (tau_d = 150) and
# fast-decaying u (tau_f = 2), driven by one tonically firing LIF neuron.
neu1 = LIF(1, monitors=['V'])
neu2 = LIF(1, monitors=['V'])

syn = STP(pre=neu1, post=neu2, conn=bp.connect.All2All(),
          U=0.2, tau_d=150., tau_f=2., monitors=['s', 'u', 'x'])
net = bp.Network(neu1, syn, neu2)
net.run(100., inputs=(neu1, 'input', 28.))

fig, gs = bp.visualize.get_figure(2, 1, 3, 7)

# top panel: the STP variables u and x
fig.add_subplot(gs[0, 0])
for var_name in ('u', 'x'):
    plt.plot(net.ts, getattr(syn.mon, var_name)[:, 0], label=var_name)
plt.legend()

# bottom panel: the synaptic variable s
fig.add_subplot(gs[1, 0])
plt.plot(net.ts, syn.mon.s[:, 0], label='s')
plt.legend()
plt.xlabel('Time (ms)')
plt.show()

## Step 2: consider the connectivity used in the paper, build up the network model

In [None]:
# From here on the time unit is seconds (the LIF/STP demos above are in ms),
# so the integration step below is 0.1 ms.
dt = 1e-4  # [s]
bp.backend.set(dt=dt)

In [None]:
# Network parameters of the rate model (times in seconds).
alpha = 1.5  # smoothness of the gain R(h) = alpha * ln(1 + exp(h / alpha))
J_EE = 8.  # recurrent synaptic efficacy within each excitatory cluster (E -> E)
J_IE = 1.75  # synaptic efficacy E -> I
J_EI = 1.1  # synaptic efficacy I -> E
tau_f = 1.5  # facilitation (STF) time constant of u [s]
tau_d = .3  # depression (STD) time constant of x [s]
U = 0.3  # baseline utilization: u relaxes back to U between spikes
tau = 0.008  # neuronal time constant of the excitatory currents h_mu [s]
tau_I = tau  # time constant of the inhibitory current h_I [s]

Ib = 8.  # constant background input to the excitatory clusters
Iinh = 0.  # background input to the inhibitory pool

cluster_num = 16  # number of excitatory clusters P

In [None]:
# Stimulation protocol used to load items: stimulus_num pulses of amplitude
# Iext_train, each lasting Ts_duration and preceded by a gap Ts_interval.
# All times are in seconds.
stimulus_num = 5
Iext_train = 225  # amplitude of each external input pulse
Ts_interval = 0.070  # gap before each pulse, i.e. between consecutive pulses [s]
Ts_duration = 0.030  # duration of each pulse [s]
duration = 2.500  # total simulated time [s]

The STP synapse model above is used to derive an expression for the postsynaptic current generated by the activity of a large, uncorrelated presynaptic population; this yields the population-level (rate) working-memory model below.

The resulting network model has three differential equations for each of $P$ excitatory clusters (synaptic current $h_\mu$ and two STP variables $u_\mu$ and $x_\mu$ for each cluster $\mu; \mu = 1,..., P$) and one additional equation for the inhibitory pool current $h_I$:

$$
\begin{gathered}
\tau \frac{d h_{\mu}}{d t}=-h_{\mu}+J_{E E} u_{\mu} x_{\mu} R_{\mu}-J_{E I} R_{I}+I_{b}+I_{e}(t) \\
\frac{d u_{\mu}}{d t}=\frac{U-u_{\mu}}{\tau_{f}}+U\left(1-u_{\mu}\right) R_{\mu} \\
\frac{d x_{\mu}}{d t}=\frac{1-x_{\mu}}{\tau_{d}}-u_{\mu} x_{\mu} R_{\mu}, \text { and } \\
\tau \frac{d h_{I}}{d t}=-h_{I}+J_{I E} \sum_{\nu} R_{\nu}
\end{gathered}
$$

where $\tau$ is the neuronal time constant, for simplicity the same for excitation and inhibition; $I_b$ is the constant background excitation; $I_e$ is the external input used to load memory items into the network; and $R_\mu = R(h_\mu)$ and $R_I = R(h_I)$ are the firing rates, where

$$
R(h)=\alpha \ln (1+\exp (h / \alpha))
$$

is the neuronal gain function, chosen as a smoothed threshold-linear function and identical for excitatory and inhibitory neurons.

In [None]:
# the excitatory cluster model and the inhibitory pool model

class WorkingMemoryModel(bp.NeuGroup):
  """Rate model with ``size`` excitatory STP clusters and one inhibitory pool.

  Integrates, for every cluster mu,
      du/dt = (U - u)/tau_f + U (1 - u) r
      dx/dt = (1 - x)/tau_d - u x r
      tau dh/dt = -h + J_EE u x r - J_EI r_inh + Iext + Ib
  and for the inhibitory pool
      tau_I dh_I/dt = -h_I + J_IE * sum(r) + Iinh,
  with rates r = R(h), r_inh = R(h_I) and R the softplus gain ``log``.

  All model parameters (U, tau_f, tau_d, J_EE, J_EI, J_IE, Ib, Iinh, tau,
  tau_I, alpha) are read from the module-level globals defined in earlier
  cells. External drive is written into ``input`` and cleared every step.
  """
  target_backend = ['numpy', 'numba']

  def __init__(self, size, **kwargs):
    # number of excitatory clusters; taken from `size` instead of the
    # global cluster_num so the arrays always match the group size
    num = int(np.prod(size))

    # inhibitory pool (scalar state)
    self.inh_h = 0.
    self.inh_r = self.log(self.inh_h)
    # excitatory clusters
    self.u = bp.ops.ones(num) * U
    self.x = bp.ops.ones(num)
    self.h = bp.ops.zeros(num)
    self.r = self.log(self.h)
    self.input = bp.ops.zeros(num)

    super(WorkingMemoryModel, self).__init__(size, **kwargs)

  @staticmethod
  @bp.odeint
  def int_exc(u, x, h, t, r, r_inh, Iext):
    du = (U - u) / tau_f + U * (1 - u) * r
    dx = (1 - x) / tau_d - u * x * r
    dh = (-h + J_EE * u * x * r - J_EI * r_inh + Iext + Ib) / tau
    return du, dx, dh

  @staticmethod
  @bp.odeint
  def int_inh(h, t, r_exc):
    h_I = (-h + J_IE * np.sum(r_exc) + Iinh) / tau_I
    return h_I

  @staticmethod
  def log(h):
    # Softplus gain R(h) = alpha * ln(1 + exp(h / alpha)).
    # np.logaddexp(0, z) == ln(1 + exp(z)) but never overflows for large z.
    return alpha * np.logaddexp(0., h / alpha)

  def update(self, _t):
    self.u, self.x, self.h = self.int_exc(
        self.u, self.x, self.h, _t, self.r, self.inh_r, self.input)
    self.r = self.log(self.h)
    self.inh_h = self.int_inh(self.inh_h, _t, self.r)
    self.inh_r = self.log(self.inh_h)
    self.input[:] = 0.

In [None]:
# the external input: pulse i (amplitude Iext_train) is delivered to cluster i

assert stimulus_num <= cluster_num, 'each stimulus needs its own cluster'
I_inputs = np.zeros((int(round(duration / dt)), cluster_num))
for i in range(stimulus_num):
    # pulse i starts after i full (gap + pulse) periods plus one more gap
    t_start = (Ts_interval + Ts_duration) * i + Ts_interval
    t_end = t_start + Ts_duration
    # round instead of truncating, so float error (e.g. 0.3 / 0.1 == 2.999...)
    # cannot shift a pulse edge by one time step
    idx_start, idx_end = int(round(t_start / dt)), int(round(t_end / dt))
    I_inputs[idx_start: idx_end, i] = Iext_train


# run the network

model = WorkingMemoryModel(cluster_num, monitors=['u', 'x', 'r', 'h'])
model.run(duration, inputs=['input', I_inputs])

In [None]:
# visualization: one panel per quantity, one trace per stimulated cluster

# A single palette for every panel, so a cluster keeps the same color in
# the whole figure; the Tableau colors are distinct and all visible on white.
colors = list(mcolors.TABLEAU_COLORS)
hist_Jux = J_EE * model.mon.u * model.mon.x  # effective efficacy J_EE * u * x

panels = [
    (model.mon.r, '$r$ (Hz)'),
    (hist_Jux, '$J_{EE}ux$'),
    (model.mon.u, 'u'),
    (model.mon.x, 'x'),
    (model.mon.h, 'h'),  # previously plotted model.mon.r by mistake
]

fig, gs = bp.visualize.get_figure(len(panels), 1, 2, 12)
for row, (trace, ylabel) in enumerate(panels):
    fig.add_subplot(gs[row, 0])
    for i in range(stimulus_num):
        plt.plot(model.mon.ts, trace[:, i], color=colors[i % len(colors)],
                 label='Cluster-{}'.format(i))
    plt.ylabel(ylabel)
    if row == 0:
        plt.legend(loc='right')
plt.xlabel('time [s]')

plt.show()

## Step 3: try to encode one item in the model (Fig. 2)

## Step 4: try to encode two items in the network (Fig. 3)