In [1]:
import jax, jax.numpy as jnp

from algorithm.neat import *
from algorithm.neat.genome.advance import AdvanceInitialize
from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse
from utils.graph import topological_sort_python
from utils import Act, Agg

import numpy as np

In [2]:
genome = AdvanceInitialize(
    num_inputs=3,
    num_outputs=3,
    hidden_cnt=2,
    max_nodes=50,
    max_conns=500,
    node_gene=NodeGeneWithoutResponse(
        # activation_default=Act.tanh,
        aggregation_default=Agg.sum,
        # activation_options=(Act.tanh,),
        aggregation_options=(Agg.sum,),
    ),
    output_transform=jnp.tanh,
)

state = genome.setup()

randkey = jax.random.PRNGKey(42)
nodes, conns = genome.initialize(state, randkey)

network = genome.network_dict(state, nodes, conns)

In [3]:
network

{'nodes': {0: {'idx': 0,
   'bias': array(0.22059791, dtype=float32),
   'agg': 'sum',
   'act': 'sigmoid'},
  1: {'idx': 1,
   'bias': array(0.7715081, dtype=float32),
   'agg': 'sum',
   'act': 'sigmoid'},
  2: {'idx': 2,
   'bias': array(1.1184921, dtype=float32),
   'agg': 'sum',
   'act': 'sigmoid'},
  3: {'idx': 3,
   'bias': array(0.6967973, dtype=float32),
   'agg': 'sum',
   'act': 'sigmoid'},
  4: {'idx': 4,
   'bias': array(0.85948837, dtype=float32),
   'agg': 'sum',
   'act': 'sigmoid'},
  5: {'idx': 5,
   'bias': array(0.19332138, dtype=float32),
   'agg': 'sum',
   'act': 'sigmoid'},
  6: {'idx': 6,
   'bias': array(-0.31763914, dtype=float32),
   'agg': 'sum',
   'act': 'sigmoid'},
  7: {'idx': 7,
   'bias': array(0.05656302, dtype=float32),
   'agg': 'sum',
   'act': 'sigmoid'}},
 'conns': {(0, 6): {'in': 0,
   'out': 6,
   'weight': array(1.6676894, dtype=float32)},
  (0, 7): {'in': 0, 'out': 7, 'weight': array(-0.05250553, dtype=float32)},
  (1, 6): {'in': 1, 'out': 

In [11]:
import sympy as sp

# symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, forward_func = genome.sympy_func(state, network)
symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, jax_forward_func = genome.sympy_func(state, network, sympy_output_transform=sp.tanh)
symbols, args_symbols, input_symbols, nodes_exprs, output_exprs, np_forward_func = genome.sympy_func(state, network, sympy_output_transform=sp.tanh, backend='numpy')


In [12]:
random_inputs = np.random.randn(3).astype(np.float32)
random_inputs, random_inputs.dtype

(array([1.0719017 , 0.09353136, 0.22664611], dtype=float32), dtype('float32'))

In [13]:
transformed = genome.transform(state, nodes, conns)
res = genome.forward(state, transformed, random_inputs)
res

Array([ 0.9743453,  0.5764604, -0.3080282], dtype=float32, weak_type=True)

In [14]:
res1 = np.array(jax_forward_func(random_inputs), dtype=np.float32)
res2 = np.array(np_forward_func(random_inputs), dtype=np.float32)
res = np.array(genome.forward(state, transformed, random_inputs))
res1, res2, res

(array([ 0.9743453,  0.5764604, -0.3080282], dtype=float32),
 array([ 0.9743453 ,  0.57646036, -0.3080282 ], dtype=float32),
 array([ 0.9743453,  0.5764604, -0.3080282], dtype=float32))

In [15]:
res1 == res, res2 == res

(array([ True,  True,  True]), array([ True, False,  True]))

In [23]:
np.floor(res1 * 10000000) / 10000000 == np.floor(res2 * 10000000) / 10000000

array([False, False,  True])