In [None]:
!pip install jax[tpu]==0.4.11 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

In [None]:
import jax
print(jax.__version__)
print(jax.devices())

In [None]:
!pip install brainpy

In [None]:
import time
from typing import Callable

import brainpy as bp
import brainpy.math as bm
import numpy as np

bm.set_dt(0.1)

In [None]:
# Membrane (taum) and excitatory/inhibitory synaptic (taue, taui) time constants,
# in the same time unit as bm.get_dt() (presumably ms).
taum, taue, taui = 20, 5, 10
# Spike threshold, then the reset/rest potentials (Vr, El; equal in this model).
Vt, Vr, El = -50, -60, -60
# Reversal potentials used by the excitatory / inhibitory COBA outputs.
Erev_exc, Erev_inh = 0., -80.
# Constant background input and refractory period.
Ib, ref = 20., 5.0
# Excitatory / inhibitory synaptic weights (g_max of the projections).
we, wi = 0.6, 6.7

In [None]:
class LIF(bp.dyn.NeuDyn):
    """Leaky integrate-and-fire population with an absolute refractory period.

    Each step applies a forward-Euler update of
    ``dV/dt = (-V + V_rest + I) / tau`` where ``I`` is the external input plus
    every current registered on this population by projections. A neuron
    spikes when ``V >= V_th``; it is then reset to ``V_reset`` and held there
    while ``t - t_last_spike <= tau_ref``.

    Args:
        size: Population size forwarded to ``NeuDyn``.
        V_init: Initializer for the membrane potential (e.g. ``bp.init.Normal``).
        sharding: Optional sharding spec forwarded to ``NeuDyn``.
        V_rest: Resting (leak) potential. Defaults to the module constant ``El``.
        V_reset: Post-spike reset potential. Defaults to ``Vr``.
        V_th: Spike threshold. Defaults to ``Vt``.
        tau: Membrane time constant. Defaults to ``taum``.
        tau_ref: Refractory period. Defaults to ``ref``.
    """

    def __init__(self, size, V_init: Callable, sharding=None, *,
                 V_rest=El, V_reset=Vr, V_th=Vt, tau=taum, tau_ref=ref):
        super().__init__(size=size, sharding=sharding)

        # parameters -- El is the leak/resting potential and Vr the reset
        # potential; both are -60 by default, so either pairing gives the same
        # dynamics, but the names now match the constants' meaning.
        self.V_rest = V_rest
        self.V_reset = V_reset
        self.V_th = V_th
        self.tau = tau
        self.tau_ref = tau_ref

        # state variables
        self.V = self.init_variable(V_init, self.mode)
        self.spike = self.init_variable(lambda s: bm.zeros(s, dtype=bool), self.mode)
        # -1e7 places the "last spike" far in the past so nobody starts refractory
        self.t_last_spike = self.init_variable(lambda s: bm.ones(s) * -1e7, self.mode)

    def update(self, inp):
        """Advance the population by one step of size ``bp.share['dt']``.

        Args:
            inp: External input, used as the initial value that all projection
                inputs are summed onto.

        Returns:
            Boolean spike array for this step.
        """
        t = bp.share['t']
        dt = bp.share['dt']
        # sum_inputs receives the current V (presumably needed by conductance-based
        # outputs) and adds every projection's contribution onto `inp`
        inp = self.sum_inputs(self.V.value, init=inp)
        refractory = (t - self.t_last_spike) <= self.tau_ref
        V = self.V + (-self.V + self.V_rest + inp) / self.tau * dt
        # refractory neurons keep their previous (reset) potential
        V = bm.where(refractory, self.V, V)
        spike = self.V_th <= V
        self.t_last_spike.value = bm.where(spike, t, self.t_last_spike)
        self.V.value = bm.where(spike, self.V_reset, V)
        self.spike.value = spike
        return spike

In [None]:
class MaskedLinear(bp.dnn.Layer):
    """Randomly masked dense linear map with one shared scalar weight.

    Every entry of a ``(num_pre, num_post)`` boolean mask is drawn
    independently with probability ``prob``; the output is
    ``(x @ mask) * weight``.

    Args:
        num_pre: Number of presynaptic inputs (mask rows).
        num_post: Number of postsynaptic targets (mask columns).
        prob: Per-pair connection probability.
        weight: Weight applied to every connection.
        sharding: Sharding spec for the mask, resolved via
            ``bm.sharding.get_sharding``.
    """

    def __init__(self, num_pre, num_post, prob, weight, sharding=None):
        super().__init__()
        print('Using masked linear')
        self.weight = weight
        mask_shape = (num_pre, num_post)

        def _sample_mask(key):
            return jax.random.bernoulli(key, prob, mask_shape)

        # Sampling under jit lets the mask be produced with the requested
        # output sharding.
        sampler = bm.jit(_sample_mask, out_shardings=bm.sharding.get_sharding(sharding))
        self.mask = sampler(bm.random.split_key())

    def update(self, x):
        """Project ``x`` through the mask and scale by ``weight``."""
        projected = x @ self.mask
        return projected * self.weight

In [None]:
class Exponential(bp.Projection):
    """Spike-driven projection onto ``post`` through an ``Expon`` synapse.

    Presynaptic spikes go through a ``MaskedLinear`` (probability ``prob``,
    weight ``g_max``), into a ``bp.dyn.Expon`` synapse with time constant
    ``tau``, and reach ``post`` via a ``bp.dyn.COBA`` output with reversal
    potential ``E``.
    """

    def __init__(self, num_pre, post, prob, g_max, tau, E):
        super().__init__()
        neuron_axis = bm.sharding.NEU_AXIS
        connection = MaskedLinear(num_pre, post.num, prob, g_max, sharding=[None, neuron_axis])
        synapse = bp.dyn.Expon.desc(post.num, tau=tau, sharding=[neuron_axis])
        output = bp.dyn.COBA.desc(E=E)
        self.proj = bp.dyn.ProjAlignPostMg1(comm=connection, syn=synapse, out=output, post=post)

    def update(self, spk):
        """Feed one step of presynaptic spikes (cast to float) into the projection."""
        self.proj.update(bm.asarray(spk, dtype=float))

In [None]:
class COBA(bp.DynSysGroup):
    """Conductance-based E/I network built from one LIF population.

    The first ``num_exc`` neurons act as excitatory sources and the remaining
    ``num_inh`` as inhibitory sources; both projections target the whole
    population.

    Args:
        scale: Multiplier on the base 3200 excitatory + 800 inhibitory neurons.
        monitor: If True, ``update`` returns the spike vector every step.
        conn_num: Expected number of presynaptic partners per neuron, counted
            over the whole population; each pair connects with probability
            ``conn_num / N``. Defaults to 80.
    """

    def __init__(self, scale, monitor=False, conn_num=80.):
        super().__init__()
        self.monitor = monitor
        self.num_exc = int(3200 * scale)
        self.num_inh = int(800 * scale)
        self.N = LIF(self.num_exc + self.num_inh, V_init=bp.init.Normal(-55., 5.))
        prob = conn_num / self.N.num
        self.E = Exponential(self.num_exc, self.N, prob=prob, E=Erev_exc, g_max=we, tau=taue)
        self.I = Exponential(self.num_inh, self.N, prob=prob, E=Erev_inh, g_max=wi, tau=taui)

    def update(self, inp=Ib):
        """Run one step: deliver the previous step's spikes, then update neurons.

        Returns:
            The spike vector when ``monitor`` is True, otherwise ``None``.
        """
        # projections are called before self.N, so they see last step's spikes
        self.E(self.N.spike[:self.num_exc])
        self.I(self.N.spike[self.num_exc:])
        self.N(inp)
        if self.monitor:
            return self.N.spike.value
        return None

In [None]:
def run_a_simulation2(scale=10, duration=1e3, platform='cpu', x64=True, monitor=False, seed=None):
  """Build a COBA network and time two consecutive jitted runs of ``duration``.

  The first call includes tracing/compilation; the second reuses the compiled
  function on the following block of step indices.

  Args:
    scale: Network size multiplier forwarded to ``COBA``.
    duration: Simulated time per run, in the unit of ``bm.get_dt()``.
    platform: Backend passed to ``bm.set_platform``.
    x64: Enable 64-bit mode for this run. x64 is disabled again on exit
      regardless of this flag.
    monitor: Record spikes, show their sharding and report the mean firing rate
      of the second run.
    seed: Passed to ``bm.random.seed``; ``None`` (default) draws a fresh seed.

  Returns:
    Tuple ``(num_neurons, first_run_seconds, second_run_seconds)``.
  """
  bm.set_platform(platform)
  bm.random.seed(seed)
  if x64:
    bm.enable_x64()
  try:
    net = COBA(scale=scale, monitor=monitor)
    num_neuron = net.N.num  # total neuron count (exc + inh)
    num_step = int(duration / bm.get_dt())

    @bm.jit
    def run(indices):
      return bm.for_loop(net.step_run, indices, progress_bar=False)

    def timed_run(indices):
      # block_until_ready makes the timing include the actual device work
      start = time.perf_counter()
      out = jax.block_until_ready(run(indices))
      return out, time.perf_counter() - start

    _, first_time = timed_run(np.arange(num_step))
    print(f'first run time = {first_time} s')

    spikes, second_time = timed_run(np.arange(num_step, num_step * 2))
    if monitor:
      # without monitoring the loop produces no array to visualize
      jax.debug.visualize_array_sharding(spikes)
    print(f'second run time = {second_time} s')

    summary = (f'scale={scale}, size={num_neuron}, first run time = {first_time} s, '
               f'second run time = {second_time} s')
    if monitor:
      spikes = bm.as_numpy(jax.device_put(spikes, jax.devices('cpu')[0]))
      rate = spikes.sum() / num_neuron / duration * 1e3
      summary += f', firing rate = {rate} Hz'
    print(summary)
    return num_neuron, first_time, second_time
  finally:
    # restore global state even if building or running the network fails
    bm.disable_x64()
    bm.clear_buffer_memory(platform)

In [None]:
# Benchmark sweep: each scale multiplies the 3200 + 800 neuron base network.
# Each entry of benchmark_results is (num_neurons, first_run_s, second_run_s).
SCALES = [1, 2, 4, 6, 8, 10, 20, 40, 60, 80, 100]
with bm.sharding.device_mesh(jax.devices(), [bm.sharding.NEU_AXIS]):
    benchmark_results = [
        run_a_simulation2(scale=s, duration=5e3, platform='tpu', x64=False, monitor=True)
        for s in SCALES
    ]