In [1]:
# Pin JAX and pull the matching libtpu wheel from the TPU release index.
# The requirement is quoted so the shell does not treat `[tpu]` as a glob pattern;
# %pip installs into the environment of the running kernel.
%pip install "jax[tpu]==0.4.11" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Looking in links: https://storage.googleapis.com/jax-releases/libtpu_releases.html
Collecting jax[tpu]==0.4.11
  Downloading jax-0.4.11.tar.gz (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m15.6 MB/s[0m eta [36m0:00:00[0m00:01[0m0:01[0m
[?25h  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Collecting jaxlib==0.4.11
  Downloading jaxlib-0.4.11-cp38-cp38-manylinux2014_x86_64.whl (71.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m71.1/71.1 MB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting libtpu-nightly==0.1.dev20230531
  Downloading https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20230531-py3-none-any.whl (170.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m170.7/170.7 MB[0m [31m2.4 MB/s[0m et

In [2]:
# Sanity check: print the installed JAX version and the devices JAX can see.
import jax
print(jax.__version__)
print(jax.devices())

0.4.11
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]


In [3]:
# Pin BrainPy to the version these benchmarks were produced with (2.4.5 in the
# install log) so re-running the notebook uses the same simulator code.
%pip install brainpy==2.4.5

Collecting brainpy
  Downloading brainpy-2.4.5-py3-none-any.whl (669 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m669.2/669.2 kB[0m [31m10.9 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Installing collected packages: brainpy
Successfully installed brainpy-2.4.5
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m23.2.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [4]:
import time
from pprint import pprint  
import numpy as np
import brainpy as bp
import brainpy.math as bm
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# Turn off BrainPy's runtime checks.
# NOTE(review): presumably this skips argument/shape validation to reduce
# overhead while benchmarking — confirm against the BrainPy documentation.
bp.check.turn_off()

In [7]:
class Tool:
  """Stimulus generation and plotting helpers for one decision-making trial.

  A trial consists of three consecutive phases:
  pre-stimulus (no input), stimulus (noisy input rates), and delay (no input).
  """

  def __init__(self, pre_stimulus_period=100., stimulus_period=1000., delay_period=500.):
    self.pre_stimulus_period = pre_stimulus_period
    self.stimulus_period = stimulus_period
    self.delay_period = delay_period
    # Spread of the stimulus rate around its mean. It is passed to
    # np.random.normal as the *standard deviation*, despite the attribute name.
    self.freq_variance = 10.
    # Length of each window over which one drawn stimulus rate is held constant.
    self.freq_interval = 50.
    self.total_period = pre_stimulus_period + stimulus_period + delay_period

  def generate_freqs(self, mean, num=1):
    """Build a per-time-step input-rate trace covering the whole trial.

    Parameters
    ----------
    mean : float
      Mean input rate during the stimulus period.
    num : int
      Number of independent traces (columns) to generate.

    Returns
    -------
    bm.Array of shape ``(int(total_period / dt), num)``. It is zero before and
    after the stimulus. During the stimulus, each ``freq_interval`` window holds
    one value drawn from a normal distribution with mean ``mean`` and standard
    deviation ``freq_variance``.
    """
    dt = bm.get_dt()
    # Phase boundaries come from cumulative times using the same int(t / dt)
    # rule as the simulation loops. The trace therefore always has exactly
    # int(total_period / dt) rows. Summing separately truncated per-phase counts
    # can leave it short when stimulus_period is not a multiple of
    # freq_interval. Reading past the end of a JAX array inside a traced loop
    # is clamped silently rather than raising, so that bug would go unnoticed.
    n_pre = int(self.pre_stimulus_period / dt)
    n_stim_end = int((self.pre_stimulus_period + self.stimulus_period) / dt)
    n_total = int(self.total_period / dt)
    n_stim_steps = n_stim_end - n_pre

    # Steps per stimulus window (at least one, so the repeat below is meaningful).
    n_interval = max(int(self.freq_interval / dt), 1)
    # Ceil-divide so a trailing partial window still gets its own random rate.
    n_windows = -(-n_stim_steps // n_interval)
    window_freqs = np.random.normal(mean, self.freq_variance, (n_windows, num))
    # Hold each window's rate for n_interval steps, then trim to the stimulus length.
    freqs_stim = np.repeat(window_freqs, n_interval, axis=0)[:n_stim_steps]

    freqs_pre = np.zeros([n_pre, num])
    freqs_delay = np.zeros([n_total - n_stim_end, num])
    all_freqs = np.concatenate([freqs_pre, freqs_stim, freqs_delay], axis=0)
    return bm.asarray(all_freqs)

  def _mark_phases(self, ax, t_start):
    """Limit ``ax`` to the trial window and draw dashed lines at each phase end."""
    ax.set_xlim(t_start, self.total_period + 1)
    phase_ends = (
      self.pre_stimulus_period,
      self.pre_stimulus_period + self.stimulus_period,
      self.pre_stimulus_period + self.stimulus_period + self.delay_period,
    )
    for t in phase_ends:
      ax.axvline(t, linestyle='dashed')

  def visualize_results(self, mon, IA_freqs, IB_freqs, t_start=0., title=None):
    """Plot spike rasters, population rates and input rates for one trial.

    Parameters
    ----------
    mon : dict
      Must contain 'ts' (time points), 'A.spike' and 'B.spike' (spike records
      aligned with 'ts').
    IA_freqs, IB_freqs : array-like
      Input-rate traces for groups A and B, one row per entry of ``mon['ts']``.
    t_start : float
      Left limit of the time axis.
    title : str, optional
      Title for the top panel.
    """
    fig, gs = bp.visualize.get_figure(4, 1, 3, 10)
    ax_a, ax_b, ax_rate, ax_input = [fig.add_subplot(gs[i, 0]) for i in range(4)]

    # Panel 1: raster of group A.
    bp.visualize.raster_plot(mon['ts'], mon['A.spike'], markersize=1, ax=ax_a)
    if title:
      ax_a.set_title(title)
    ax_a.set_ylabel("Group A")
    self._mark_phases(ax_a, t_start)

    # Panel 2: raster of group B.
    bp.visualize.raster_plot(mon['ts'], mon['B.spike'], markersize=1, ax=ax_b)
    ax_b.set_ylabel("Group B")
    self._mark_phases(ax_b, t_start)

    # Panel 3: smoothed population firing rates.
    rateA = bp.measure.firing_rate(mon['A.spike'], width=10.)
    rateB = bp.measure.firing_rate(mon['B.spike'], width=10.)
    ax_rate.plot(mon['ts'], rateA, label="Group A")
    ax_rate.plot(mon['ts'], rateB, label="Group B")
    ax_rate.set_ylabel('Population activity [Hz]')
    self._mark_phases(ax_rate, t_start)
    ax_rate.legend()

    # Panel 4: the external input rates that drove the trial.
    ax_input.plot(mon['ts'], IA_freqs, label="group A")
    ax_input.plot(mon['ts'], IB_freqs, label="group B")
    ax_input.set_ylabel("Input activity [Hz]")
    self._mark_phases(ax_input, t_start)
    ax_input.legend()
    ax_input.set_xlabel("Time [ms]")

    plt.show()

In [8]:
class ExpSyn(bp.Projection):
  """Conductance-based projection through an exponential synapse.

  Spikes from ``pre`` go through a connection layer: all-to-all, or
  one-to-one with weight ``g_max``. They drive a ``bp.dyn.Expon`` synapse
  (time constant ``tau``) whose state is sized to ``post.num``. The synapse
  acts on ``post`` through a COBA output with reversal potential ``E``.
  ``bp.dyn.ProjAlignPostMg2`` wires everything together.

  Parameters
  ----------
  pre, post : neuron groups exposing ``.num``.
  conn : {'all2all', 'one2one'}
    Connection pattern.
  delay : delay passed to ``ProjAlignPostMg2``. The network passes ``None``
    for its external inputs and a number for recurrent projections.
  g_max : float
    Synaptic weight (maximum conductance).
  tau : float
    Decay time constant of the synapse.
  E : float
    Reversal potential of the synaptic current.

  Raises
  ------
  ValueError
    If ``conn`` is not a supported connection pattern.
  """

  def __init__(self, pre, post, conn, delay, g_max, tau, E):
    super().__init__()
    if conn == 'all2all':
      comm = bp.dnn.AllToAll(pre.num, post.num, g_max)
    elif conn == 'one2one':
      comm = bp.dnn.OneToOne(pre.num, g_max)
    else:
      raise ValueError(f"Unsupported conn {conn!r}; expected 'all2all' or 'one2one'.")
    # Synaptic state lives on the post side and is sharded along the neuron axis.
    syn = bp.dyn.Expon.desc(post.num, tau=tau, sharding=[bm.sharding.NEU_AXIS])
    out = bp.dyn.COBA.desc(E=E)
    self.proj = bp.dyn.ProjAlignPostMg2(
      pre=pre, delay=delay, comm=comm,
      syn=syn, out=out, post=post
    )

In [9]:
class NMDA(bp.Projection):
  """NMDA projection with magnesium block.

  Uses a ``bp.dyn.NMDA`` synapse sized to ``pre.num`` (a=0.5, tau_decay=100,
  tau_rise=2) and a connection layer: all-to-all, or one-to-one with weight
  ``g_max``. The output is a ``bp.dyn.MgBlock`` (E=0, cc_Mg=1.0) onto
  ``post``. ``bp.dyn.ProjAlignPreMg2`` combines these parts.

  Parameters
  ----------
  pre, post : neuron groups exposing ``.num``.
  conn : {'all2all', 'one2one'}
    Connection pattern.
  delay : delay passed to ``ProjAlignPreMg2``.
  g_max : float
    Synaptic weight (maximum conductance).

  Raises
  ------
  ValueError
    If ``conn`` is not a supported connection pattern.
  """

  def __init__(self, pre, post, conn, delay, g_max):
    super().__init__()
    if conn == 'all2all':
      comm = bp.dnn.AllToAll(pre.num, post.num, g_max)
    elif conn == 'one2one':
      comm = bp.dnn.OneToOne(pre.num, g_max)
    else:
      raise ValueError(f"Unsupported conn {conn!r}; expected 'all2all' or 'one2one'.")
    # Synaptic state is sized to the presynaptic group and sharded along the neuron axis.
    syn = bp.dyn.NMDA.desc(pre.num, a=0.5, tau_decay=100., tau_rise=2., sharding=[bm.sharding.NEU_AXIS])
    out = bp.dyn.MgBlock(E=0., cc_Mg=1.0)
    self.proj = bp.dyn.ProjAlignPreMg2(
      pre=pre, delay=delay, syn=syn,
      comm=comm, out=out, post=post
    )

In [10]:
class DecisionMakingNet(bp.DynSysGroup):
  """Spiking decision-making network with two competing selective groups.

  Populations
  -----------
  A, B    : selective excitatory LIF groups, each a fraction ``f`` of the
            excitatory pool.
  N       : non-selective excitatory LIF group (the rest of the excitatory pool).
  I       : inhibitory LIF group.
  IA, IB  : Poisson stimulus inputs to A and B. Their rates are stored in
            ``bm.Variable``s so they can be changed while the network runs.
  noise_* : background Poisson inputs (2400 Hz) to every group.

  Projections use AMPA (``ExpSyn``, tau=2, E=0), NMDA (``NMDA``) and GABA_A
  (``ExpSyn``, tau=5, E=-70) synapses. Recurrent connections within A and
  within B are scaled by ``w_pos``. Connections into A/B from the other
  excitatory groups are scaled by ``w_neg``, which is < 1 for the default
  ``w_pos``.

  NOTE(review): the structure and constants appear to follow Wang (2002),
  Neuron 36:955-968 — confirm before relying on that correspondence.

  Parameters
  ----------
  scale : float
    Multiplies every population size. Recurrent conductances are divided by
    ``scale``; external-input conductances are not.
  f : float
    Fraction of the excitatory neurons assigned to each selective group.

  Attributes
  ----------
  num_A, num_B, num_N, num_I : int
    Population sizes.
  num : int
    Total number of LIF neurons (Poisson inputs are not counted).
  """

  def __init__(self, scale=1., f=0.15):
    super().__init__()
    # Number of neurons in each group of the network.
    num_exc = int(1600 * scale)
    num_I, num_A, num_B = int(400 * scale), int(f * num_exc), int(f * num_exc)
    num_N = num_exc - num_A - num_B
    self.num_A, self.num_B, self.num_N, self.num_I = num_A, num_B, num_N, num_I
    self.num = num_A + num_B + num_N + num_I

    poisson_freq = 2400.  # Hz
    w_pos = 1.7
    # Chosen so that f * w_pos + (1 - f) * w_neg == 1.
    w_neg = 1. - f * (w_pos - 1.) / (1. - f)
    g_ext2E_AMPA = 2.1  # nS
    g_ext2I_AMPA = 1.62  # nS
    # Recurrent conductances shrink as 1/scale while population sizes grow with scale.
    g_E2E_AMPA = 0.05 / scale  # nS
    g_E2I_AMPA = 0.04 / scale  # nS
    g_E2E_NMDA = 0.165 / scale  # nS
    g_E2I_NMDA = 0.13 / scale  # nS
    g_I2E_GABAa = 1.3 / scale  # nS
    g_I2I_GABAa = 1.0 / scale  # nS

    # LIF parameters shared by all groups; state is sharded along the neuron axis.
    neu_par = dict(V_rest=-70., V_reset=-55., V_th=-50., V_initializer=bp.init.OneInit(-70.),
                   sharding=[bm.sharding.NEU_AXIS])

    # Excitatory (pyramidal) neurons: selective groups A and B, non-selective group N.
    self.A = bp.dyn.LifRef(num_A, tau=20., R=0.04, tau_ref=2., **neu_par)
    self.B = bp.dyn.LifRef(num_B, tau=20., R=0.04, tau_ref=2., **neu_par)
    self.N = bp.dyn.LifRef(num_N, tau=20., R=0.04, tau_ref=2., **neu_par)

    # Inhibitory neurons (interneurons).
    self.I = bp.dyn.LifRef(num_I, tau=10., R=0.05, tau_ref=1., **neu_par)

    # Poisson stimulus inputs. 'freqs' is a bm.Variable so its value can be
    # updated while the simulation runs.
    self.IA = bp.dyn.PoissonGroup(num_A, freqs=bm.Variable(bm.zeros(1)), sharding=[bm.sharding.NEU_AXIS])
    self.IB = bp.dyn.PoissonGroup(num_B, freqs=bm.Variable(bm.zeros(1)), sharding=[bm.sharding.NEU_AXIS])

    # Background noise inputs, one Poisson source per neuron of each group.
    self.noise_B = bp.dyn.PoissonGroup(num_B, freqs=poisson_freq, sharding=[bm.sharding.NEU_AXIS])
    self.noise_A = bp.dyn.PoissonGroup(num_A, freqs=poisson_freq, sharding=[bm.sharding.NEU_AXIS])
    self.noise_N = bp.dyn.PoissonGroup(num_N, freqs=poisson_freq, sharding=[bm.sharding.NEU_AXIS])
    self.noise_I = bp.dyn.PoissonGroup(num_I, freqs=poisson_freq, sharding=[bm.sharding.NEU_AXIS])

    # Stimulus inputs: one-to-one AMPA, no delay.
    self.IA2A = ExpSyn(self.IA, self.A, 'one2one', None, g_ext2E_AMPA, tau=2., E=0.)
    self.IB2B = ExpSyn(self.IB, self.B, 'one2one', None, g_ext2E_AMPA, tau=2., E=0.)

    # AMPA projections from N (weakened by w_neg onto the selective groups).
    self.N2B_AMPA = ExpSyn(self.N, self.B, 'all2all', 0.5, g_E2E_AMPA * w_neg, tau=2., E=0.)
    self.N2A_AMPA = ExpSyn(self.N, self.A, 'all2all', 0.5, g_E2E_AMPA * w_neg, tau=2., E=0.)
    self.N2N_AMPA = ExpSyn(self.N, self.N, 'all2all', 0.5, g_E2E_AMPA, tau=2., E=0.)
    self.N2I_AMPA = ExpSyn(self.N, self.I, 'all2all', 0.5, g_E2I_AMPA, tau=2., E=0.)

    # NMDA projections from N.
    self.N2B_NMDA = NMDA(self.N, self.B, 'all2all', 0.5, g_E2E_NMDA * w_neg)
    self.N2A_NMDA = NMDA(self.N, self.A, 'all2all', 0.5, g_E2E_NMDA * w_neg)
    self.N2N_NMDA = NMDA(self.N, self.N, 'all2all', 0.5, g_E2E_NMDA)
    self.N2I_NMDA = NMDA(self.N, self.I, 'all2all', 0.5, g_E2I_NMDA)

    # AMPA projections from B (self-excitation scaled by w_pos, onto A by w_neg).
    self.B2B_AMPA = ExpSyn(self.B, self.B, 'all2all', 0.5, g_E2E_AMPA * w_pos, tau=2., E=0.)
    self.B2A_AMPA = ExpSyn(self.B, self.A, 'all2all', 0.5, g_E2E_AMPA * w_neg, tau=2., E=0.)
    self.B2N_AMPA = ExpSyn(self.B, self.N, 'all2all', 0.5, g_E2E_AMPA, tau=2., E=0.)
    self.B2I_AMPA = ExpSyn(self.B, self.I, 'all2all', 0.5, g_E2I_AMPA, tau=2., E=0.)

    # NMDA projections from B.
    self.B2B_NMDA = NMDA(self.B, self.B, 'all2all', 0.5, g_E2E_NMDA * w_pos)
    self.B2A_NMDA = NMDA(self.B, self.A, 'all2all', 0.5, g_E2E_NMDA * w_neg)
    self.B2N_NMDA = NMDA(self.B, self.N, 'all2all', 0.5, g_E2E_NMDA)
    self.B2I_NMDA = NMDA(self.B, self.I, 'all2all', 0.5, g_E2I_NMDA)

    # AMPA projections from A (self-excitation scaled by w_pos, onto B by w_neg).
    self.A2B_AMPA = ExpSyn(self.A, self.B, 'all2all', 0.5, g_E2E_AMPA * w_neg, tau=2., E=0.)
    self.A2A_AMPA = ExpSyn(self.A, self.A, 'all2all', 0.5, g_E2E_AMPA * w_pos, tau=2., E=0.)
    self.A2N_AMPA = ExpSyn(self.A, self.N, 'all2all', 0.5, g_E2E_AMPA, tau=2., E=0.)
    self.A2I_AMPA = ExpSyn(self.A, self.I, 'all2all', 0.5, g_E2I_AMPA, tau=2., E=0.)

    # NMDA projections from A.
    self.A2B_NMDA = NMDA(self.A, self.B, 'all2all', 0.5, g_E2E_NMDA * w_neg)
    self.A2A_NMDA = NMDA(self.A, self.A, 'all2all', 0.5, g_E2E_NMDA * w_pos)
    self.A2N_NMDA = NMDA(self.A, self.N, 'all2all', 0.5, g_E2E_NMDA)
    self.A2I_NMDA = NMDA(self.A, self.I, 'all2all', 0.5, g_E2I_NMDA)

    # Inhibitory (GABA_A, E=-70) projections from I onto every group.
    self.I2B = ExpSyn(self.I, self.B, 'all2all', 0.5, g_I2E_GABAa, tau=5., E=-70.)
    self.I2A = ExpSyn(self.I, self.A, 'all2all', 0.5, g_I2E_GABAa, tau=5., E=-70.)
    self.I2N = ExpSyn(self.I, self.N, 'all2all', 0.5, g_I2E_GABAa, tau=5., E=-70.)
    self.I2I = ExpSyn(self.I, self.I, 'all2all', 0.5, g_I2I_GABAa, tau=5., E=-70.)

    # Background noise projections: one-to-one AMPA, no delay.
    self.noise2B = ExpSyn(self.noise_B, self.B, 'one2one', None, g_ext2E_AMPA, tau=2., E=0.)
    self.noise2A = ExpSyn(self.noise_A, self.A, 'one2one', None, g_ext2E_AMPA, tau=2., E=0.)
    self.noise2N = ExpSyn(self.noise_N, self.N, 'one2one', None, g_ext2E_AMPA, tau=2., E=0.)
    self.noise2I = ExpSyn(self.noise_I, self.I, 'one2one', None, g_ext2I_AMPA, tau=2., E=0.)

In [11]:
def run_a_simulation(scale, mu0=40., coherence=25.6):
  """Simulate one decision-making trial and plot the results.

  Parameters
  ----------
  scale : float
    Network size multiplier passed to ``DecisionMakingNet``.
  mu0 : float
    Baseline mean stimulus rate shared by both selective groups.
  coherence : float
    Stimulus coherence in percent. Group A's mean rate is
    ``mu0 * (1 + coherence / 100)`` and group B's is
    ``mu0 * (1 - coherence / 100)``.
  """
  # simulation tools
  pre_stimulus_period = 100.  # time before the external stimuli are given
  stimulus_period = 1000.  # time within which the external stimuli are given
  delay_period = 500.  # time after the external stimuli are removed
  tool = Tool(pre_stimulus_period, stimulus_period, delay_period)

  # stimulus
  IA_freqs = tool.generate_freqs(mu0 + mu0 / 100. * coherence)
  IB_freqs = tool.generate_freqs(mu0 - mu0 / 100. * coherence)

  # network
  net = DecisionMakingNet(scale)

  def run(i):
    # One simulation step: publish step index/time, set input rates, advance the net.
    bp.share.save(i=i, t=bm.get_dt() * i)
    net.IA.freqs.value = IA_freqs[i]
    net.IB.freqs.value = IB_freqs[i]
    net.update()
    return {'A.spike': net.A.spike.value, 'B.spike': net.B.spike.value}

  # running
  indices = np.arange(int(tool.total_period / bm.get_dt()))
  mon = bm.for_loop(run, indices)
  mon['ts'] = indices * bm.get_dt()
  tool.visualize_results(mon, IA_freqs, IB_freqs)


In [12]:
def simulate_a_trial(scale, platform='cpu', x64=False, monitor=True, second_run=False):
  """Build one decision-making network and time its simulation.

  Parameters
  ----------
  scale : float
    Network size multiplier passed to ``DecisionMakingNet``.
  platform : str
    Backend passed to ``bm.set_platform`` (e.g. 'cpu', 'tpu').
  x64 : bool
    Enable 64-bit mode (``bm.enable_x64``) for this trial.
  monitor : bool
    If True, the loop returns the spikes of groups A and B at every step.
  second_run : bool
    If True, call the compiled loop a second time and report that call as
    ``run_time``. The first call includes tracing/compilation. The state is not
    reset between the calls, so the second one continues from where the first
    ended. If False (default), ``run_time`` equals ``exe_time``, as before.

  Returns
  -------
  dict with 'num' (total LIF neuron count), 'exe_time' (first call, including
  compilation, in seconds) and 'run_time' (seconds; see ``second_run``).

  Side effects
  ------------
  Always disables x64 mode and calls ``bm.clear_buffer_memory(platform)``
  before returning, even when the simulation raises.
  """
  bm.set_platform(platform)
  bm.random.seed()
  if x64:
    bm.enable_x64()

  try:
    # simulation tools
    pre_stimulus_period = 100.  # time before the external stimuli are given
    stimulus_period = 1000.  # time within which the external stimuli are given
    delay_period = 500.  # time after the external stimuli are removed
    tool = Tool(pre_stimulus_period, stimulus_period, delay_period)

    # stimulus
    mu0 = 40.
    coherence = 25.6
    IA_freqs = tool.generate_freqs(mu0 + mu0 / 100. * coherence)
    IB_freqs = tool.generate_freqs(mu0 - mu0 / 100. * coherence)

    # network
    net = DecisionMakingNet(scale)

    def run(i):
      bp.share.save(i=i, t=bm.get_dt() * i)
      net.IA.freqs.value = IA_freqs[i]
      net.IB.freqs.value = IB_freqs[i]
      net.update()
      if monitor:
        return {'A.spike': net.A.spike.value, 'B.spike': net.B.spike.value}

    @bm.jit
    def jit_run(indices):
      return bm.for_loop(run, indices)

    n_step = int(tool.total_period / bm.get_dt())
    indices = np.arange(n_step)

    # First call: includes tracing and compilation of `jit_run`.
    # block_until_ready makes the timer wait for asynchronous device work.
    t0 = time.perf_counter()
    jax.block_until_ready(jit_run(indices))
    exe_time = time.perf_counter() - t0

    run_time = exe_time
    if second_run:
      # Same input shapes, so this should reuse the compiled function and
      # measure execution only.
      t0 = time.perf_counter()
      jax.block_until_ready(jit_run(indices))
      run_time = time.perf_counter() - t0

    message = (f'platform = {platform}, x64 = {x64}, scale = {scale}, '
               f'first run = {exe_time} s')
    if second_run:
      message += f', second run = {run_time} s'
    print(message)
    num = net.num
  finally:
    # Restore default precision and release device buffers even on failure,
    # so one failed trial does not leak state into the next.
    bm.disable_x64()
    bm.clear_buffer_memory(platform)

  return {'num': num, 'exe_time': exe_time, 'run_time': run_time}

In [13]:
def benchmark(devices, platform='cpu', x64=True,
              scales=(1, 4, 8, 10, 20, 40, 60, 80, 100), n_trials=10):
  """Time ``simulate_a_trial`` repeatedly for each network scale.

  Parameters
  ----------
  devices : sequence of devices
    Devices over which the neuron axis is meshed (e.g. ``jax.devices()``).
  platform : str
    Backend passed to ``simulate_a_trial``.
  x64 : bool
    Whether each trial runs in 64-bit mode.
  scales : iterable of numbers
    Network scales to benchmark; the default matches the original hard-coded list.
  n_trials : int
    Number of repeated trials per scale.

  Returns
  -------
  dict mapping each scale to ``{'exetime': [...], 'runtime': [...]}``.
  The per-scale results are also printed as they complete.
  """
  results = {}
  for scale in scales:
    res = {'exetime': [], 'runtime': []}
    for _ in range(n_trials):
      # Build a device mesh whose single axis is NEU_AXIS; the network's
      # `sharding=[bm.sharding.NEU_AXIS]` arguments refer to this axis.
      with bm.sharding.device_mesh(devices, [bm.sharding.NEU_AXIS]):
        r = simulate_a_trial(scale=scale, platform=platform, x64=x64, monitor=True)
      res['exetime'].append(r['exe_time'])
      res['runtime'].append(r['run_time'])
    print(f'Scale = {scale}:')
    pprint(res)
    results[scale] = res
  return results

# The benchmarks below shard the network across all 8 TPU cores returned by jax.devices().

In [15]:
# 32-bit (x64=False) TPU benchmark over benchmark()'s default scales, meshed across all devices.
benchmark(jax.devices(), platform='tpu', x64=False)

platform = tpu, x64 = False, scale = 1, first run = 14.618347406387329 s
platform = tpu, x64 = False, scale = 1, first run = 14.200562238693237 s
platform = tpu, x64 = False, scale = 1, first run = 14.191334247589111 s
platform = tpu, x64 = False, scale = 1, first run = 14.406397819519043 s
platform = tpu, x64 = False, scale = 1, first run = 14.399168491363525 s
platform = tpu, x64 = False, scale = 1, first run = 14.504820346832275 s
platform = tpu, x64 = False, scale = 1, first run = 14.53603482246399 s
platform = tpu, x64 = False, scale = 1, first run = 14.44247055053711 s
platform = tpu, x64 = False, scale = 1, first run = 14.480181217193604 s
platform = tpu, x64 = False, scale = 1, first run = 14.6030113697052 s
Scale = 1:
{'exetime': [14.618347406387329,
             14.200562238693237,
             14.191334247589111,
             14.406397819519043,
             14.399168491363525,
             14.504820346832275,
             14.53603482246399,
             14.44247055053711,
 

In [16]:
def benchmark2(devices, scales, platform='cpu', x64=True, n_trials=10):
  """Time ``simulate_a_trial`` repeatedly for each of the given network scales.

  Parameters
  ----------
  devices : sequence of devices
    Devices over which the neuron axis is meshed (e.g. ``jax.devices()``).
  scales : iterable of numbers
    Network scales to benchmark.
  platform : str
    Backend passed to ``simulate_a_trial``.
  x64 : bool
    Whether each trial runs in 64-bit mode.
  n_trials : int
    Number of repeated trials per scale (default 10, as before).

  Returns
  -------
  dict mapping each scale to ``{'exetime': [...], 'runtime': [...]}``.
  The per-scale results are also printed as they complete.
  """
  final_results = dict()
  for scale in scales:
    res = {'exetime': [], 'runtime': []}
    for _ in range(n_trials):
      # Mesh the neuron axis (NEU_AXIS) across `devices` for this trial.
      with bm.sharding.device_mesh(devices, [bm.sharding.NEU_AXIS]):
        r = simulate_a_trial(scale=scale, platform=platform, x64=x64, monitor=True)
      res['exetime'].append(r['exe_time'])
      res['runtime'].append(r['run_time'])
    print(f'Scale = {scale}:')
    pprint(res)
    final_results[scale] = res
  return final_results

In [19]:
# 32-bit (x64=False) TPU benchmark for the large scales 200-1000, meshed across all devices.
benchmark2(jax.devices(), [200, 400, 800, 1000], platform='tpu', x64=False)

platform = tpu, x64 = False, scale = 200, first run = 17.05660915374756 s
platform = tpu, x64 = False, scale = 200, first run = 17.01578426361084 s
platform = tpu, x64 = False, scale = 200, first run = 17.0588960647583 s
platform = tpu, x64 = False, scale = 200, first run = 17.122342348098755 s
platform = tpu, x64 = False, scale = 200, first run = 18.164459466934204 s
platform = tpu, x64 = False, scale = 200, first run = 17.087402820587158 s
platform = tpu, x64 = False, scale = 200, first run = 17.13105869293213 s
platform = tpu, x64 = False, scale = 200, first run = 17.09291696548462 s
platform = tpu, x64 = False, scale = 200, first run = 17.226580381393433 s
platform = tpu, x64 = False, scale = 200, first run = 18.08594560623169 s
Scale = 200:
{'exetime': [17.05660915374756,
             17.01578426361084,
             17.0588960647583,
             17.122342348098755,
             18.164459466934204,
             17.087402820587158,
             17.13105869293213,
             17.09

In [14]:
# 64-bit (x64=True) TPU benchmark over benchmark()'s default scales, meshed across all devices.
benchmark(jax.devices(), platform='tpu', x64=True)

platform = tpu, x64 = True, scale = 1, first run = 27.850502014160156 s
platform = tpu, x64 = True, scale = 1, first run = 29.47843337059021 s
platform = tpu, x64 = True, scale = 1, first run = 30.392443895339966 s
platform = tpu, x64 = True, scale = 1, first run = 30.546841859817505 s
platform = tpu, x64 = True, scale = 1, first run = 30.813852310180664 s
platform = tpu, x64 = True, scale = 1, first run = 30.806926250457764 s
platform = tpu, x64 = True, scale = 1, first run = 31.28420639038086 s
platform = tpu, x64 = True, scale = 1, first run = 31.52595543861389 s
platform = tpu, x64 = True, scale = 1, first run = 31.42830181121826 s
platform = tpu, x64 = True, scale = 1, first run = 31.265384197235107 s
Scale = 1:
{'exetime': [27.850502014160156,
             29.47843337059021,
             30.392443895339966,
             30.546841859817505,
             30.813852310180664,
             30.806926250457764,
             31.28420639038086,
             31.52595543861389,
            

In [17]:
# 64-bit (x64=True) TPU benchmark for the large scales 200-1000, meshed across all devices.
benchmark2(jax.devices(), [200, 400, 800, 1000], platform='tpu', x64=True)

platform = tpu, x64 = True, scale = 200, first run = 40.15212893486023 s
platform = tpu, x64 = True, scale = 200, first run = 41.32249903678894 s
platform = tpu, x64 = True, scale = 200, first run = 42.35366463661194 s
platform = tpu, x64 = True, scale = 200, first run = 41.226806640625 s
platform = tpu, x64 = True, scale = 200, first run = 41.76349496841431 s
platform = tpu, x64 = True, scale = 200, first run = 48.265284061431885 s
platform = tpu, x64 = True, scale = 200, first run = 47.717525482177734 s
platform = tpu, x64 = True, scale = 200, first run = 48.328460454940796 s
platform = tpu, x64 = True, scale = 200, first run = 47.13247036933899 s
platform = tpu, x64 = True, scale = 200, first run = 43.94406318664551 s
Scale = 200:
{'exetime': [40.15212893486023,
             41.32249903678894,
             42.35366463661194,
             41.226806640625,
             41.76349496841431,
             48.265284061431885,
             47.717525482177734,
             48.328460454940796,