In [1]:
from lib.sde.cell3 import Cell3, CellBatch3
from lib.sde.grn3 import GRNMain3, ParamGene3
import numpy as np
import jax.numpy as jnp

In [2]:
# Build a GRN; the first argument is presumably the gene count (the params
# below have 2 columns and the repr shows genes G_0 and G_1).
grn = GRNMain3(2, 0, 0)
# Hand-picked parameter matrix: one row per parameter, one column per gene.
# Row order, matching the GRN repr printed below:
#   init, noise, b, m, expr, deg, theta
grn._params = np.array([
    [0.34922326, 0.438748],    # init
    [7.5721583, 2.001825],     # noise
    [7.615428, 8.380326],      # b
    [3.8867943, 7.3130646],    # m
    [3.710491, 8.649075],      # expr
    [7.5654016, 0.93269396],   # deg
    [6.519226, 8.221411],      # theta
])
# NOTE(review): the regulation tree is not set explicitly here, yet the repr
# below shows one (e.g. "(NOT 1 AND 0)" for G_1). TODO confirm whether it
# must be pinned before compile() for the notebook to be reproducible.
grn.compile()

In [3]:
# Create a single cell driven by the GRN defined above.
c = Cell3(grn=grn)

In [4]:
# Inspect the fresh cell: quantities are set, while activation, expression
# and derivative start at zero.
c.__dict__

{'grn': >> G_0: init: 0.35; noise: 7.57; b: 7.62; m: 3.89; expr: 3.71; deg: 7.57; theta: 6.52; tree : 0
 >> G_1: init: 0.44; noise: 2.00; b: 8.38; m: 7.31; expr: 8.65; deg: 0.93; theta: 8.22; tree : (NOT 1 AND 0),
 'quantities': DeviceArray([0.5092498 , 0.07729276], dtype=float32),
 'activation': DeviceArray([0., 0.], dtype=float32),
 'expression': DeviceArray([0., 0.], dtype=float32),
 'derivative': DeviceArray([0., 0.], dtype=float32)}

In [5]:
# Advance the cell by one simulation step.
c.run_step()

In [6]:
# After one step: quantities changed and activation/expression/derivative
# are now populated.
c.__dict__

{'grn': >> G_0: init: 0.35; noise: 7.57; b: 7.62; m: 3.89; expr: 3.71; deg: 7.57; theta: 6.52; tree : 0
 >> G_1: init: 0.44; noise: 2.00; b: 8.38; m: 7.31; expr: 8.65; deg: 0.93; theta: 8.22; tree : (NOT 1 AND 0),
 'quantities': DeviceArray([0.43287706, 0.35237017], dtype=float32),
 'activation': DeviceArray([0.7017571, 0.5516028], dtype=float32),
 'expression': DeviceArray([0.7017571, 0.3146659], dtype=float32),
 'derivative': DeviceArray([-1.248816 ,  2.6494787], dtype=float32)}

In [7]:
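# Checklist of behaviors to verify with the tests below:
# - run_step changes the cell state
# - the noise parameter has an effect
# - the expr and deg parameters have an effect
# - the init parameter works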

In [8]:
def test_cell_step(cell):
    x = cell.quantities.copy()
    cell.run_step()
    assert (x != cell.quantities).any()

# Note: this advances `c` by one more step, so later cells see a new state.
test_cell_step(c)

In [9]:
def run_test_cell_param(grn, id_to_check):
    """Check that parameter row `id_to_check` influences a single Cell3 step.

    Rows of `grn._params` follow the GRN repr order
    (init, noise, b, m, expr, deg, theta); columns are genes. Noise (row 1)
    is zeroed in every variant, presumably so that a step is deterministic.
    Three cells start from the same quantities (0.5 for every gene):

      * cell1 and cell2 share params with row `id_to_check` set to 0, so
        their quantities must be identical after one step;
      * cell3 uses params with row `id_to_check` set to 1, so its quantities
        must differ from cell1's.

    Works for any number of genes (columns of `grn._params`).

    Raises:
        AssertionError: if the step is not reproducible or the parameter has
            no visible effect.
    """
    def _variant(value):
        # Copy the GRN and take a writable NumPy copy of its params
        # (the original may hold an immutable JAX array).
        variant = grn.copy()
        variant._params = np.array(variant._params)
        variant._params[1, :] = 0  # disable noise
        variant._params[id_to_check, :] = value
        # NOTE(review): params are edited after grn.compile(); confirm that
        # run_step reads _params directly and no recompile is needed.
        return variant

    n_genes = np.shape(grn._params)[1]
    def_quant = jnp.full(n_genes, 0.5)

    grn1 = _variant(0)
    grn2 = _variant(1)

    cell1 = Cell3(grn1)
    cell1.quantities = def_quant

    cell2 = Cell3(grn1)
    cell2.quantities = def_quant

    cell3 = Cell3(grn2)
    cell3.quantities = def_quant

    for cell in (cell1, cell2, cell3):
        cell.run_step()

    assert (cell1.quantities == cell2.quantities).all()
    assert not (cell1.quantities == cell3.quantities).all()
    
def test_cell_params(grn):
    for i in range(1, 7):
        run_test_cell_param(grn, 1)
    
# Run the parameter-sensitivity checks against the GRN built in cell 2.
test_cell_params(grn)

In [10]:
# Wrap the single cell `c` in a batch; the next cells show that stepping the
# batch updates `c.quantities`.
batch = CellBatch3([c])

In [11]:
# Quantities before the batch step.
c.quantities

DeviceArray([0.31717312, 0.55876344], dtype=float32)

In [12]:
# Advance the batch by one step.
batch.run_step()

In [13]:
# Quantities after the batch step; they differ from the values shown before it.
c.quantities

DeviceArray([0.06982404, 0.62264025], dtype=float32)

In [14]:
def test_cell_step_batch(cell):
    x = cell.quantities.copy()
    batch = CellBatch3([c])
    batch.run_step()
    assert (x != cell.quantities).any()

# Note: this advances `c` by one more step through a batch.
test_cell_step_batch(c)

In [16]:
def run_test_cell_param_batch(grn, id_to_check):
    """Check that parameter row `id_to_check` influences a CellBatch3 step.

    Rows of `grn._params` follow the GRN repr order
    (init, noise, b, m, expr, deg, theta); columns are genes. Noise (row 1)
    is zeroed in every variant, presumably so that a step is deterministic.
    Each cell starts at 0.5 for every gene and is stepped through its own
    one-element CellBatch3:

      * cell1 and cell2 share params with row `id_to_check` set to 0, so
        their quantities must be identical after one step;
      * cell3 uses params with row `id_to_check` set to 1, so its quantities
        must differ from cell1's.

    Works for any number of genes (columns of `grn._params`).

    Raises:
        AssertionError: if the step is not reproducible or the parameter has
            no visible effect.
    """
    def _variant(value):
        # Copy the GRN and take a writable NumPy copy of its params
        # (the original may hold an immutable JAX array).
        variant = grn.copy()
        variant._params = np.array(variant._params)
        variant._params[1, :] = 0  # disable noise
        variant._params[id_to_check, :] = value
        # NOTE(review): params are edited after grn.compile(); confirm that
        # run_step reads _params directly and no recompile is needed.
        return variant

    n_genes = np.shape(grn._params)[1]
    def_quant = jnp.full(n_genes, 0.5)

    def _batched_cell(variant):
        # One cell at the default quantities, wrapped in its own batch.
        cell = Cell3(variant)
        cell.quantities = def_quant
        return cell, CellBatch3([cell])

    grn1 = _variant(0)
    grn2 = _variant(1)

    cell1, batch1 = _batched_cell(grn1)
    cell2, batch2 = _batched_cell(grn1)
    cell3, batch3 = _batched_cell(grn2)

    for batch in (batch1, batch2, batch3):
        batch.run_step()

    assert (cell1.quantities == cell2.quantities).all()
    assert not (cell1.quantities == cell3.quantities).all()
    
def test_cell_params_batch(grn):
    for i in range(1, 7):
        run_test_cell_param_batch(grn, 1)
    
# Exercise every checked parameter row through the CellBatch3 code path
# (bug fix: this cell previously re-ran the non-batch test_cell_params).
test_cell_params_batch(grn)