In [1]:
import jax, jax.numpy as jnp

from algorithm.neat import *
from algorithm.neat.genome.advance import AdvanceInitialize
from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse
from utils.graph import topological_sort_python
from tensorneat.utils import ACT, AGG

import numpy as np

In [32]:
genome = AdvanceInitialize(
    num_inputs=20,
    num_outputs=1,
    hidden_cnt=2,
    max_nodes=30,
    max_conns=50,
    node_gene=NodeGeneWithoutResponse(
        activation_default= ACT.identity,
        aggregation_default=AGG.sum,
        # activation_options=(ACT.tanh, ACT.sigmoid, ACT.identity, ACT.clamped),
        activation_options=( ACT.identity, ),
        aggregation_options=(AGG.sum,),
    ),
    # output_transform=jnp.tanh,
)

state = genome.setup()

randkey = jax.random.PRNGKey(42)
nodes, conns = genome.initialize(state, randkey)

network = genome.network_dict(state, nodes, conns)

In [33]:
import sympy as sp

# symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, forward_func = genome.sympy_func(state, network)
# symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, jax_forward_func = genome.sympy_func(state, network, sympy_output_transform=sp.tanh)
symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, jax_forward_func = genome.sympy_func(state, network)

In [38]:
random_inputs = np.random.randn(20).astype(np.float32)
random_inputs = jax.device_put(random_inputs)
random_inputs, random_inputs.dtype

(Array([-0.10080967, -2.373122  , -0.12224621,  1.0417817 ,  0.26311624,
        -0.04573117,  0.5329444 ,  1.9844177 , -0.5471916 , -3.0961084 ,
         0.07978257, -1.0657575 , -1.6740963 ,  1.2435746 , -0.5811825 ,
         0.8970058 , -0.4379712 ,  0.9084878 , -1.0984142 ,  0.33063456],      dtype=float32),
 dtype('float32'))

In [39]:
transformed = genome.transform(state, nodes, conns)
res = genome.forward(state, transformed, random_inputs)
res

Array([-3.769288], dtype=float32, weak_type=True)

In [40]:
res1 = jnp.array(jax_forward_func(random_inputs))
res2 = genome.forward(state, transformed, random_inputs)
res1, res2

(Array([-3.7692878], dtype=float32),
 Array([-3.769288], dtype=float32, weak_type=True))

In [41]:
all(res1 == res)

False

In [73]:
random_inputs = np.random.randn(1000).astype(np.float32)
random_inputs = jax.device_put(random_inputs)

In [74]:
res1 = 1.243123123 + random_inputs * 1.12413243123123

In [75]:
res2 = random_inputs * 1.12413243123123 + 1.243123123

In [76]:
res1 == res2

Array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,

In [77]:
resres1 = 0.
for i in range(1000):
    resres1 += res1[i]
resres2 = jnp.sum(res2)
resres1, resres2

(Array(1251.1074, dtype=float32), Array(1251.1078, dtype=float32))

In [78]:
resres1 == resres2

Array(False, dtype=bool)

In [66]:
res1, res2

(Array([-0.0913986 ,  0.6177338 ,  3.704111  ,  1.067648  ,  3.5810733 ,
         0.3716032 , -0.10655618,  1.3503847 ,  0.97305036,  0.7711922 ],      dtype=float32),
 Array([-0.0913986 ,  0.6177338 ,  3.704111  ,  1.067648  ,  3.5810733 ,
         0.3716032 , -0.10655618,  1.3503847 ,  0.97305036,  0.7711922 ],      dtype=float32))

In [20]:
real = 10
full = 50000
random_inputs = np.random.randn(real).astype(np.float32)
random_inputs = jax.device_put(random_inputs)

In [21]:
all_nans = jnp.full((full,), jnp.nan)
large = all_nans.at[:real].set(random_inputs)

In [22]:
res1 = jnp.sum(large, where=~jnp.isnan(large))
res1

Array(-5.8886395, dtype=float32)

In [23]:
res2 = jnp.sum(random_inputs)
res2

Array(-5.8886395, dtype=float32)

In [24]:
res1 == res2

Array(True, dtype=bool)