<a href="https://colab.research.google.com/github/cahcharm/Neuro-and-Complexity-Science/blob/CSHA2021/project/project_spiking_circuit_model_for_working_memory.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Spiking circuit model for working memory**

Build a spiking circuit model of working memory based on the synaptic theory of working memory (Mongillo, Barak & Tsodyks, *Science*, 2008).

* Step 1: build the LIF neuron, the voltage-jump synapse, and the STP synapse.
* Step 2: implement the connectivity used in the paper and assemble the network model.
* Step 3: encode one item in the network (Fig. 2).
* Step 4: encode two items in the network (Fig. 3).


In [None]:
pip install brainpy-simulator

Collecting brainpy-simulator
  Downloading brainpy-simulator-1.0.2.tar.gz (126 kB)
[?25l[K     |██▋                             | 10 kB 20.5 MB/s eta 0:00:01[K     |█████▏                          | 20 kB 22.5 MB/s eta 0:00:01[K     |███████▉                        | 30 kB 25.5 MB/s eta 0:00:01[K     |██████████▍                     | 40 kB 24.0 MB/s eta 0:00:01[K     |█████████████                   | 51 kB 22.9 MB/s eta 0:00:01[K     |███████████████▋                | 61 kB 24.8 MB/s eta 0:00:01[K     |██████████████████▏             | 71 kB 20.0 MB/s eta 0:00:01[K     |████████████████████▊           | 81 kB 21.4 MB/s eta 0:00:01[K     |███████████████████████▍        | 92 kB 22.3 MB/s eta 0:00:01[K     |██████████████████████████      | 102 kB 22.9 MB/s eta 0:00:01[K     |████████████████████████████▋   | 112 kB 22.9 MB/s eta 0:00:01[K     |███████████████████████████████▏| 122 kB 22.9 MB/s eta 0:00:01[K     |████████████████████████████████| 126 kB 22.9 

In [None]:
import numpy as np
import brainpy as bp

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

## Step 1: Build the LIF neuron, voltage-jump synapse, and STP synapse

In [None]:
## LIF neuron
class LIF(bp.NeuGroup):
  """Leaky integrate-and-fire (LIF) neuron group.

  Membrane dynamics: tau * dV/dt = -V + V_rest + R * I, integrated with the
  exponential Euler method. When V reaches V_th the neuron emits a spike, V is
  set to V_reset, and integration is skipped while
  (t - t_last_spike) <= t_ref (absolute refractory period).

  Parameters
  ----------
  size : int or tuple
      Number of neurons (forwarded to ``bp.NeuGroup``).
  t_ref : float
      Refractory period, in the simulation's time unit (the plots in this
      notebook label time in ms).
  V_rest, V_reset, V_th : float
      Resting potential, reset potential and spike threshold.
  R : float
      Membrane resistance multiplying the accumulated input current.
  tau : float
      Membrane time constant.
  **kwargs
      Forwarded to ``bp.NeuGroup`` (e.g. ``monitors=['V']``).

  State arrays (length ``self.num``): ``V``, ``spike``, ``refractory``,
  ``t_last_spike`` and ``input``. ``input`` accumulates external and synaptic
  current during a step and is cleared at the end of ``update``.
  """
  target_backend = ['numpy', 'numba']

  # Right-hand side of the membrane equation. `bp.odeint` wraps it into a
  # one-step exponential-Euler integrator; `update` assigns its return value
  # as the neuron's new V.
  @staticmethod
  @bp.odeint(method='exponential_euler')
  def integral(V, t, Iext, V_rest, R, tau):
    dvdt = (-V + V_rest + R * Iext) / tau
    return dvdt

  def __init__(self, size, t_ref=1., V_rest=0., V_reset=0., 
               V_th=20., R=1., tau=10., **kwargs):
    super(LIF, self).__init__(size=size, **kwargs)
    
    # parameters
    self.V_rest = V_rest
    self.V_reset = V_reset
    self.V_th = V_th
    self.R = R
    self.tau = tau
    self.t_ref = t_ref

    # variables
    # Last spike time starts far in the past, so no neuron begins refractory.
    self.t_last_spike = bp.ops.ones(self.num) * -1e7
    self.refractory = bp.ops.zeros(self.num, dtype=bool)
    self.spike = bp.ops.zeros(self.num, dtype=bool)
    self.V = bp.ops.ones(self.num) * V_rest
    # Summed input current for the current step (external + synaptic).
    self.input = bp.ops.zeros(self.num)

  def update(self, _t):
    # Advance every neuron by one step at time `_t`.
    # NOTE(review): written as an explicit scalar loop, presumably so it also
    # runs under the 'numba' backend listed in target_backend.
    for i in range(self.num):
      spike = False
      refractory = (_t - self.t_last_spike[i]) <= self.t_ref
      if not refractory:
        V = self.integral(self.V[i], _t, self.input[i],
                          self.V_rest, self.R, self.tau)
        spike = (V >= self.V_th)
        if spike:
          # Reset and start the refractory period from this step.
          V = self.V_reset
          self.t_last_spike[i] = _t
          refractory = True
        self.V[i] = V
      # While refractory, V keeps its previous (reset) value.
      self.spike[i] = spike
      self.refractory[i] = refractory
      # Clear the accumulator; inputs/synapses add to it again next step.
      self.input[i] = 0.

In [None]:
## STP synapse
class STP(bp.TwoEndConn):
  """Exponential synapse with short-term plasticity (Tsodyks-Markram form).

  Each connection ``i`` carries three variables:

  * ``s`` -- synaptic drive; decays with ``tau`` and is added (after
    ``delay``) to the postsynaptic neuron's ``input``.
  * ``u`` -- utilization (facilitation variable); decays to 0 with ``tau_f``.
  * ``x`` -- fraction of available resources (depression variable); recovers
    to 1 with ``tau_d``.

  On a presynaptic spike, with ``u-`` and ``x-`` the values just before the
  spike::

      u+ = u- + U * (1 - u-)
      s  = s + A * u+ * x-
      x+ = x- - u+ * x-

  NOTE(review): here ``u`` relaxes to 0 between spikes, whereas Mongillo et
  al. (2008) write du/dt = (U - u) / tau_f (relaxation to U). Confirm which
  form is intended before comparing against the paper's figures.

  Parameters
  ----------
  pre, post : bp.NeuGroup
      Pre-/postsynaptic groups; ``pre.spike`` is read and ``post.input`` is
      incremented.
  conn : connector
      Called as ``conn(pre.size, post.size)``; must provide ``pre_ids`` and
      ``post_ids``.
  delay : float
      Constant transmission delay applied to ``s``.
  U : float
      Utilization increment per spike.
  tau_f, tau_d, tau : float
      Facilitation, depression-recovery and synaptic decay time constants.
  A : float
      Synaptic efficacy scaling the jump in ``s``.
  **kwargs
      Forwarded to ``bp.TwoEndConn`` (e.g. ``monitors=['s', 'u', 'x']``).
  """
  target_backend = 'general'

  # Continuous decay/recovery between spikes; spike-triggered jumps are
  # applied in `update`. `bp.odeint` turns this into a one-step
  # exponential-Euler integrator returning the updated (s, u, x).
  @staticmethod
  @bp.odeint(method='exponential_euler')
  def integral(s, u, x, t, tau, tau_d, tau_f):
    dsdt = - s / tau 
    dudt = - u / tau_f 
    dxdt = (1 - x) / tau_d
    return dsdt, dudt, dxdt
  
  def __init__(self, pre, post, conn, delay=0., U=0.15, tau_f=1500., tau_d=200., tau=8., A=1.,  **kwargs):
    # parameters
    self.tau_d = tau_d
    self.tau_f = tau_f
    self.tau = tau
    self.U = U
    self.A = A
    self.delay = delay

    # connections: one entry per (pre_id, post_id) pair
    self.conn = conn(pre.size, post.size)
    self.pre_ids, self.post_ids = conn.requires('pre_ids', 'post_ids')
    self.num = len(self.pre_ids)

    # variables (one value per connection)
    self.s = bp.ops.zeros(self.num)
    self.x = bp.ops.ones(self.num)    # resources start fully available
    self.u = bp.ops.zeros(self.num)
    # Delay line carrying s to the postsynaptic input.
    self.I_syn = self.register_constant_delay('I_syn', size=self.num, delay_time=delay)
    
    super(STP, self).__init__(pre=pre, post=post, **kwargs)

  def update(self, _t):
    for i in range(self.num):
      pre_id, post_id = self.pre_ids[i], self.post_ids[i]

      # One step of decay/recovery: s, u, x are now the values just before
      # any spike arriving at time _t.
      s, u, x = self.integral(self.s[i], self.u[i], self.x[i], _t,
                              self.tau, self.tau_d, self.tau_f)
      if self.pre.spike[pre_id]:
        # Spike-triggered jumps all use the freshly integrated (pre-spike)
        # state, so u, x and s refer to the same time step.
        u += self.U * (1 - u)
        s += self.A * u * x   # x is still the pre-spike value here
        x -= u * x
      self.s[i] = s
      self.u[i] = u
      self.x[i] = x

      # output: push into the delay line, add the delayed drive to the target
      self.I_syn.push(i, self.s[i])
      self.post.input[post_id] += self.I_syn.pull(i)

In [None]:
## STP demo: short-term depression (STD)
# One LIF neuron driven by a constant current projects onto a second LIF
# neuron through a single STP synapse; we record the synapse's u, x and s.
# Fast facilitation decay (tau_f) relative to depression recovery (tau_d)
# puts the synapse in the depression-dominated regime. For facilitation
# (STF), try a small U with tau_f >> tau_d.
STD_PARAMS = dict(U=0.2, tau_d=150., tau_f=2.)
DURATION = 100.   # simulation length (ms)
I_EXT = 28.       # constant drive; above V_th=20 with the LIF defaults (R=1, V_rest=0), so neu1 fires repeatedly

neu1 = LIF(1, monitors=['V'])
neu2 = LIF(1, monitors=['V'])
syn = STP(pre=neu1, post=neu2, conn=bp.connect.All2All(),
          monitors=['s', 'u', 'x'], **STD_PARAMS)
net = bp.Network(neu1, syn, neu2)
net.run(DURATION, inputs=(neu1, 'input', I_EXT))

# plot
fig, gs = bp.visualize.get_figure(2, 1, 3, 7)

ax_ux = fig.add_subplot(gs[0, 0])
ax_ux.plot(net.ts, syn.mon.u[:, 0], label='u')
ax_ux.plot(net.ts, syn.mon.x[:, 0], label='x')
ax_ux.set(ylabel='u, x', title='Short-term depression (STD)')
ax_ux.legend()

ax_s = fig.add_subplot(gs[1, 0])
ax_s.plot(net.ts, syn.mon.s[:, 0], label='s')
ax_s.set(xlabel='Time (ms)', ylabel='s')
ax_s.legend()

plt.show()