In [1]:
import evox as ev
from evox.workflows import StdWorkflow
import numpy as np
from model import IzhikevichModel
import neuron as nr
from neuron.units import mV, ms
import pandas as pd

In [4]:
params = pd.read_csv('2-000-1_paramset.csv', names=["k", "a", "b", "d", "C", "Vr", "Vt", "Vmin", "Vp"])

In [None]:
soma = nr.h.Section(name="soma")
soma.insert('hh')
iclamp = nr.h.IClamp(soma(0.5))
iclamp.delay = 10
iclamp.dur = 230

solutions           = [7.57575758e+01, 3.50000000e+01, 9.39393939e+01, 6.56565657e+01,
       3.43434343e-03, 1.51515152e+00, 1.21212121e+00]
soma.L              = solutions[0]
soma.diam           = solutions[1]
soma(0.5).hh.gkbar  = solutions[2]
soma(0.5).hh.gnabar = solutions[3]
soma(0.5).hh.gl     = solutions[4]
soma.cm             = solutions[5]
iclamp.amp          = solutions[6]

v = nr.h.Vector().record(soma(0.5)._ref_v)
t = nr.h.Vector().record(nr.h._ref_t)

nr.h.load_file("stdrun.hoc")
nr.h.finitialize(-65 * mV)
nr.h.continuerun(250 * ms)

: 

In [164]:
help(ev.monitors.StdSOMonitor)

Help on class StdSOMonitor in module evox.monitors.std_so_monitor:

class StdSOMonitor(builtins.object)
 |  StdSOMonitor(record_topk=1, record_fit_history=True, record_pop_history=False)
 |  
 |  Standard single-objective monitor
 |  Used for single-objective workflow,
 |  can monitor fitness and the population.
 |  
 |  Parameters
 |  ----------
 |  record_topk
 |      Control how many elite solutions are recorded.
 |      Default is 1, which will record the best individual.
 |  record_fit_history
 |      Whether to record the full history of fitness value.
 |      Default to True. Setting it to False may reduce memory usage.
 |  
 |  Methods defined here:
 |  
 |  __init__(self, record_topk=1, record_fit_history=True, record_pop_history=False)
 |      Initialize self.  See help(type(self)) for accurate signature.
 |  
 |  close(self)
 |  
 |  flush(self)
 |  
 |  get_best_fitness(self)
 |  
 |  get_best_solution(self)
 |  
 |  get_history(self)
 |  
 |  get_last(self)
 |  
 |  get_to

In [165]:
import jax
import jax.numpy as jnp

algorithm = ev.algorithms.PSO(
    lb=jnp.full(shape=(2,), fill_value=-32),
    ub=jnp.full(shape=(2,), fill_value=32),
    pop_size=100
)
problem = ev.problems.numerical.Ackley()
monitor = ev.monitors.StdSOMonitor(record_topk=10)

workflow = ev.workflows.StdWorkflow(
    algorithm,
    problem,
    monitors=[monitor]
)

key = jax.random.PRNGKey(42)
state = workflow.init(key)

for i in range(100):
    state = workflow.step(state)

In [169]:
f0 = bkplt.figure()
f0.scatter(monitor.get_topk_solutions()[:,0], monitor.get_topk_solutions()[:,1])
bkplt.show(f0)

In [2]:
import jax
import jax.numpy as jnp

In [193]:
from datetime import datetime
import jax.numpy as jnp
t0 = datetime.now()
jnp.sum(jnp.abs(jnp.array([1,2,3]) - jnp.array([2,3,4])))
t1 = datetime.now()
t1-t0

datetime.timedelta(microseconds=1737)

In [3]:
params_limit = {
   'L'      : np.linspace(4,100,100).tolist(), #um
   'diam'   : np.linspace(5,140,100).tolist(), #um
   'gkbar'  : np.linspace(0,100,100).tolist(),
   'gnabar' : np.linspace(0,100,100).tolist(),
   'gl'     : np.linspace(0,.01,100).tolist(),
   'cm'     : np.linspace(0,10,100).tolist(),
   'amp'    : np.linspace(0,10,100).tolist(),
}

In [25]:
ev.problems.numerical.Ackley()._module_name

AttributeError: 'Ackley' object has no attribute '_module_name'

In [217]:
from jax.numpy import std
import numpy as np

center_init=jnp.array([2.0e+01, 2.0e+01, 3.6e-02, 1.2e-01, 3.0e-04, 1.0e+00, 5.0e-01])
lb=jnp.array([4,5,0,0,0,0,0])
ub=jnp.array([100,140,100,100,.01,10,10])

algorithm = ev.algorithms.PSO(
       pop_size=100,
       lb=lb,
       ub=ub
)
tmp = 0
@ev.jit_class
class MatchSignal(ev.Problem):
       def __init__(self) -> None:
              super().__init__()
              self.vsignal = np.fromfile("RS.csv")
              self.soma = nr.h.Section(name="soma")
              self.soma.insert('hh')
              self.iclamp = nr.h.IClamp(self.soma(0.5))
              self.iclamp.delay = 10
              self.iclamp.dur = 230
              nr.h.load_file("stdrun.hoc")
              self.v = nr.h.Vector().record(self.soma(0.5)._ref_v)
              self.t = nr.h.Vector().record(nr.h._ref_t)
       def __fitness_func(self, solution, i):
              print(solution.at[i,0].get())
              self.soma.L              = solution.at[0].get()
              self.soma.diam           = solution.at[1].get()
              self.soma(0.5).hh.gkbar  = solution.at[2].get()
              self.soma(0.5).hh.gnabar = solution.at[3].get()
              self.soma(0.5).hh.gl     = solution.at[4].get()
              self.soma.cm             = solution.at[5].get()
              self.iclamp.amp          = solution.at[6].get()
              nr.h.finitialize(-65 * mV)
              nr.h.continuerun(250 * ms)
              return 1e3-jnp.sum(jnp.abs(self.v.as_numpy() - self.vsignal))
       def evaluate(self, state, solutions):
              global tmp
              tmp = solutions
              fitness = jnp.zeros(solutions.shape[0])
              for i in range(solutions.shape[0]):
                     fitness[i] = self._fitness_func(solutions,i)
              return fitness, state
problem = MatchSignal()
monitor = ev.monitors.StdSOMonitor(record_topk=10)

workflow = ev.workflows.StdWorkflow(
       algorithm,
       problem,
       monitors=[monitor]
)

In [17]:
import jax
import jax.numpy as jnp
import scipy.special
@jax.jit
def jv(v, z):
    v, z = jnp.asarray(v), jnp.asarray(z)
    print(v)
    assert jnp.issubdtype(v.dtype, jnp.integer)
    z = z.astype(jnp.result_type(float, z.dtype))

    _scipy_jv = lambda v, z: scipy.special.jv(v, z).astype(z.dtype)

    result_shape_dtype = jax.ShapeDtypeStruct(
        shape=jnp.broadcast_shapes(v.shape,z.shape),
        dtype=z.dtype
    )
    return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)

In [23]:
from functools import partial
j1 = partial(jv, 1)
z = jnp.arange(5.0)
%timeit j1(z)

954 µs ± 42.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [3]:
help(jnp.result_type)

Help on function result_type in module jax._src.numpy.lax_numpy:

result_type(*args: 'Any') -> 'DType'
    Returns the type that results from applying the NumPy
    
    LAX-backend implementation of :func:`numpy.result_type`.
    
    *Original docstring below.*
    
    type promotion rules to the arguments.
    
    Type promotion in NumPy works similarly to the rules in languages
    like C++, with some slight differences.  When both scalars and
    arrays are used, the array's type takes precedence and the actual value
    of the scalar is taken into account.
    
    For example, calculating 3*a, where a is an array of 32-bit floats,
    intuitively should result in a 32-bit float output.  If the 3 is a
    32-bit integer, the NumPy rules indicate it can't convert losslessly
    into a 32-bit float, so a 64-bit float should be the result type.
    By examining the value of the constant, '3', we see that it fits in
    an 8-bit integer, which can be cast losslessly into the 32-bit

In [11]:
import jax
import numpy as np
import jax.numpy as jnp
@jax.jit
def f(x):
    x = jnp.array(x)
    tmp = x
    print(tmp)
    return tmp
x=np.random.randn(3,4)
f(f(x))

Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=1/0)>


Array([[ 0.94180053, -1.1163607 , -0.14706208, -1.8684183 ],
       [ 0.87468606,  1.1291984 ,  2.2395656 , -2.4310803 ],
       [ 0.66759086, -1.4921861 ,  1.3350419 ,  1.0678979 ]],      dtype=float32)

In [232]:
help(jax.pure_callback)

Help on function pure_callback_api in module jax._src.callback:

pure_callback_api(callback: 'Callable[..., Any]', result_shape_dtypes: 'Any', *args: 'Any', sharding: 'SingleDeviceSharding | None' = None, vectorized: 'bool' = False, **kwargs: 'Any')
    Applies a functionally pure Python callable. Works under :func:`jit`/:func:`~pmap`/etc.
    
    ``pure_callback`` enables calling a Python function in JIT-ed JAX functions.
    The input ``callback`` will be passed NumPy arrays in place of JAX arrays and
    should also return NumPy arrays. Execution takes place on CPU, like any
    Python+NumPy function.
    
    The callback is treated as functionally pure, meaning it has no side-effects
    and its output value depends only on its argument values. As a consequence, it
    is safe to be called multiple times (e.g. when transformed by :func:`~vmap` or
    :func:`~pmap`), or not to be called at all when e.g. the output of a
    `jit`-decorated function has no data dependence on its val

In [219]:
key = jax.random.PRNGKey(42)
state = workflow.init(key)

In [220]:
for i in range(10):
    state = workflow.step(state)

Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=2/0)>


ValueError: L must be > 0.

In [27]:
params_limit = {
   'L'      : np.linspace(4,100,100).tolist(), #um
   'diam'   : np.linspace(5,140,100).tolist(), #um
   'gkbar'  : np.linspace(0,100,100).tolist(),
   'gnabar' : np.linspace(0,100,100).tolist(),
   'gl'     : np.linspace(0,.01,100).tolist(),
   'cm'     : np.linspace(0,10,100).tolist(),
   'amp'    : np.linspace(0,10,100).tolist()
}

In [None]:
from bokeh.io import output_notebook
import bokeh.plotting as bkplt
output_notebook()

: 

In [None]:
f = bkplt.figure(x_axis_label='t (ms)', y_axis_label='v (mV)')
f.line(t.as_numpy(), v.as_numpy(), line_width=3)
bkplt.show(f)

: 

In [88]:
vsignal = 

['T',
 'all',
 'any',
 'argmax',
 'argmin',
 'argpartition',
 'argsort',
 'astype',
 'base',
 'byteswap',
 'choose',
 'clip',
 'compress',
 'conj',
 'conjugate',
 'copy',
 'ctypes',
 'cumprod',
 'cumsum',
 'data',
 'diagonal',
 'dot',
 'dtype',
 'dump',
 'dumps',
 'fill',
 'flags',
 'flat',
 'flatten',
 'getfield',
 'imag',
 'item',
 'itemset',
 'itemsize',
 'max',
 'mean',
 'min',
 'nbytes',
 'ndim',
 'newbyteorder',
 'nonzero',
 'partition',
 'prod',
 'ptp',
 'put',
 'ravel',
 'real',
 'repeat',
 'reshape',
 'resize',
 'round',
 'searchsorted',
 'setfield',
 'setflags',
 'shape',
 'size',
 'sort',
 'squeeze',
 'std',
 'strides',
 'sum',
 'swapaxes',
 'take',
 'tobytes',
 'tofile',
 'tolist',
 'tostring',
 'trace',
 'transpose',
 'var',
 'view']

In [92]:
v.as_numpy().tofile('RS.csv', format='csv')

In [97]:
help(np.fromstring)

Help on built-in function fromstring in module numpy:

fromstring(...)
    fromstring(string, dtype=float, count=-1, *, sep, like=None)
    
    A new 1-D array initialized from text data in a string.
    
    Parameters
    ----------
    string : str
        A string containing the data.
    dtype : data-type, optional
        The data type of the array; default: float.  For binary input data,
        the data must be in exactly this format. Most builtin numeric types are
        supported and extension types may be supported.
    
        .. versionadded:: 1.18.0
            Complex dtypes.
    
    count : int, optional
        Read this number of `dtype` elements from the data.  If this is
        negative (the default), the count will be determined from the
        length of the data.
    sep : str, optional
        The string separating numbers in the data; extra whitespace between
        elements is also ignored.
    
        .. deprecated:: 1.14
            Passing ``sep=''``

In [110]:
sols_string = "[[2.00000000e+01 2.00000000e+01 3.60000000e-02 1.20000000e-01 3.00000000e-04 1.00000000e+00 5.00000000e-01] [7.96363636e+01 8.00000000e+01 3.60000000e-02 1.20000000e-01 6.96969697e-03 6.46464646e+00 5.05050505e-01]\
 [3.89090909e+01 6.63636364e+01 3.60000000e-02 1.20000000e-01\
  5.55555556e-03 5.45454545e+00 5.00000000e-01]\
 [6.89696970e+01 4.86363636e+01 3.60000000e-02 1.20000000e-01\
  6.26262626e-03 5.45454545e+00 5.00000000e-01]\
 [4.56969697e+01 8.00000000e+01 3.60000000e-02 1.20000000e-01\
  8.88888889e-03 8.68686869e+00 4.04040404e-01]\
 [8.15757576e+01 4.31818182e+01 3.60000000e-02 1.20000000e-01\
  4.14141414e-03 9.39393939e+00 2.42424242e+00]\
 [5.34545455e+01 8.54545455e+01 3.60000000e-02 1.20000000e-01\
  9.89898990e-03 8.38383838e+00 5.00000000e-01]\
 [5.05454545e+01 3.22727273e+01 3.60000000e-02 1.20000000e-01\
  9.39393939e-03 9.19191919e+00 5.00000000e-01]\
 [8.44848485e+01 9.63636364e+01 3.60000000e-02 1.20000000e-01\
  8.28282828e-03 6.16161616e+00 5.05050505e+00]\
 [4.56969697e+01 1.11363636e+02 3.60000000e-02 1.20000000e-01\
  6.66666667e-03 7.37373737e+00 4.64646465e+00]\
 [4.95757576e+01 4.86363636e+01 3.60000000e-02 1.20000000e-01\
  8.48484848e-03 8.08080808e-01 5.00000000e-01]\
 [4.27878788e+01 3.50000000e+01 3.60000000e-02 1.20000000e-01\
  8.58585859e-03 1.31313131e+00 5.00000000e-01]\
 [5.44242424e+01 6.09090909e+01 3.60000000e-02 1.20000000e-01\
  6.96969697e-03 9.39393939e+00 4.54545455e+00]\
 [1.17575758e+01 8.54545455e+01 3.60000000e-02 1.20000000e-01\
  9.59595960e-03 7.07070707e-01 9.09090909e-01]\
 [4.27878788e+01 4.86363636e+01 3.60000000e-02 1.20000000e-01\
  8.08080808e-04 0.00000000e+00 5.05050505e-01]\
 [4.56969697e+01 8.00000000e+01 3.60000000e-02 1.20000000e-01\
  9.29292929e-03 4.54545455e+00 5.75757576e+00]\
 [5.05454545e+01 5.68181818e+01 3.60000000e-02 1.20000000e-01\
  4.14141414e-03 9.39393939e+00 5.15151515e+00]\
 [5.44242424e+01 6.63636364e+01 3.60000000e-02 1.20000000e-01\
  6.56565657e-03 5.45454545e+00 8.08080808e+00]\
 [1.85454545e+01 8.54545455e+01 3.60000000e-02 1.20000000e-01\
  8.48484848e-03 1.00000000e+00 4.64646465e+00]\
 [2.00000000e+01 1.05909091e+02 3.60000000e-02 1.20000000e-01\
  8.98989899e-03 9.59595960e+00 7.87878788e+00]]"

In [119]:
sols = np.fromstring(sols_string.replace("[", "").replace("]", ""), sep=" ").reshape((20,7))

In [133]:
sols[14]

array([4.27878788e+01, 4.86363636e+01, 3.60000000e-02, 1.20000000e-01,
       8.08080808e-04, 0.00000000e+00, 5.05050505e-01])

In [None]:
solution = np.zeros(3)
h = HodgkinHuxleyModel()
h.dur = duration
h.delay = delay
h.amp = abs(solution[0])
h.area = abs(solution[1])
h.cm = abs(solution[2])
h.Initialize(-65*mV)
h.ContinueRun(250*ms)

In [None]:
import jax.numpy as np
from neuron.units import ms, mV
import pygad
import json
import os

spike_type = 'FS_NoAdaptation'
#units: [v] = [v_inf] = mV, [tau] = ms, [g] = uS/mm2, [A] = mm2, [i_ext] = nA, [c] = nF/mm2

delay = 10
duration = 230

with open("Izhikevich_Model_Params.json") as json_file:
    types_params = json.load(json_file)
iz = IzhikevichModel(types_params[spike_type])
iz.dur = duration
iz.delay = delay
iz.Initialize(-65*mV)
iz.ContinueRun(250*ms)

solution = np.zeros(3)
def fitness_func(ga_instance, solution, solution_idx):
  h = HodgkinHuxleyModel()
  h.dur = duration
  h.delay = delay
  h.amp = abs(solution[0])
  h.area = abs(solution[1])
  h.cm = abs(solution[2])
  h.Initialize(-65*mV)
  h.ContinueRun(250*ms)
  vs_hh = h.ys[:,0]
  fitness = 1.0/(0.001+np.sum(np.abs(vs_hh - iz.vs)))
  return fitness
    
num_generations = 1
num_parents_mating = 6

fitness_function = fitness_func

sol_per_pop = 10
num_genes = len(solution)

init_range_low = 0
init_range_high = 100

parent_selection_type = "sss"
keep_parents = 1

crossover_type = "two_points"

mutation_type = "random"

def on_gen(ga_instance):
    ga_instance.best_solution()

try:
  ga_instance = pygad.load(filename=spike_type)
except FileNotFoundError:
  print("creating a new instance \n")
  ga_instance = pygad.GA(num_generations=num_generations,
                        num_parents_mating=num_parents_mating,
                        fitness_func=fitness_function,
                        sol_per_pop=sol_per_pop,
                        num_genes=num_genes,
                        init_range_low=init_range_low,
                        init_range_high=init_range_high,
                        parent_selection_type=parent_selection_type,
                        keep_parents=keep_parents,
                        crossover_type=crossover_type,
                        mutation_by_replacement=True,
                        mutation_type=mutation_type,
                        mutation_num_genes=2,
                        random_mutation_min_val=0,
                        random_mutation_max_val=100)
ga_instance.run()

if not os.path.exists('log.txt'):
  log_file = open('log.txt', 'a')
  log_file.write("|\tno. of generations\t|\tminimum fitness \t|\t50th percentile fitness\t|\tmaximum fitness \t|\n")
  log_file.write("-----------------------------------------------------------------------------------------------------------------------\n")
else:
  log_file = open('log.txt', 'a')
fitnesses = ga_instance.last_generation_fitness
log_file.write("|\t\t{}\t\t|\t{:.2E}\t\t|\t{:.2E}\t\t|\t{:.2E}\t\t|\n".format(ga_instance.generations_completed, fitnesses.min(), np.percentile(fitnesses, 50), fitnesses.max()))
log_file.close()

while ga_instance.last_generation_fitness.mean() < 0.1 :
  ga_instance.run()
  fitnesses = ga_instance.last_generation_fitness
  log_file = open('log.txt', 'a')
  log_file.write("|\t\t{}\t\t|\t{:.2E}\t\t|\t{:.2E}\t\t|\t{:.2E}\t\t|\n".format(ga_instance.generations_completed, fitnesses.min(), np.percentile(fitnesses, 50), fitnesses.max()))
  log_file.close()
  ga_instance.save(filename=spike_type)

In [None]:
pso = algorithms.PSO(
    lb=jnp.full(shape=(2,), fill_value=-32),
    ub=jnp.full(shape=(2,), fill_value=32),
    pop_size=100,
)
ackley = problems.numerical.Ackley()

In [None]:
monitor = monitors.StdSOMonitor(record_topk=50)
workflow = workflows.StdWorkflow(
    pso,
    ackley,
    monitor=monitor,
    record_pop=True)

In [None]:
key = random.PRNGKey(42)
state = workflow.init(key)

In [None]:
# run the workflow for 100 steps
for i in range(1000):
    state = workflow.step(state)

In [None]:
monitor.flush()
solutions = monitor.get_topk_solutions()

In [None]:
monitor.get_topk_fitness()

In [None]:
sols = solutions.reshape((2,50))
plt.scatter(sols[0], sols[1])

In [None]:
solutions.shape

In [None]:
norm_sols = jnp.sqrt(jnp.square(solutions[:][0]) + jnp.square(solutions[:][1]**2))

In [None]:
solutions.reshape((2,50))

In [None]:
pso.ask(state)