In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import sympy as sp

from algorithm.neat import *
from algorithm.neat.genome.advance import AdvanceInitialize
from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse
from utils.graph import topological_sort_python
from utils import Act, Agg

In [2]:
# Node gene restricted to `sum` aggregation. No activation arguments are passed, so the
# node gene's own activation defaults apply (the SymPy output below shows `sigmoid`).
sum_node_gene = NodeGeneWithoutResponse(
    aggregation_default=Agg.sum,
    aggregation_options=(Agg.sum,),
)

# 3 inputs -> 1 output. hidden_cnt / max_nodes / max_conns are passed straight to
# AdvanceInitialize (presumably: initial hidden nodes and node/connection capacity — confirm).
genome = AdvanceInitialize(
    num_inputs=3,
    num_outputs=1,
    hidden_cnt=1,
    max_nodes=50,
    max_conns=500,
    node_gene=sum_node_gene,
)

state = genome.setup()

# A fixed PRNG key makes the initialized genome identical on every run.
randkey = jax.random.PRNGKey(42)
nodes, conns = genome.initialize(state, randkey)
network = genome.network_dict(state, nodes, conns)

In [3]:
# Convert the initialized network into SymPy expressions. `precision=3` appears to round the
# coefficients to 3 decimals (see the output below), so `forward_func` only approximates
# the exact network. `output_exprs` and `forward_func` are used by later cells.
(
    symbols,
    input_symbols,
    nodes_exprs,
    output_exprs,
    forward_func,
) = genome.sympy_func(state, network, precision=3)
output_exprs

[-0.535*sigmoid(0.346*i0 + 0.044*i1 - 0.482*i2 + 0.875) - 0.264]

In [4]:
# Raw LaTeX source of the (single, num_outputs=1) output expression; printed rather than
# displayed so the string can be copied verbatim.
latex_src = sp.latex(output_exprs[0])
print(latex_src)

- 0.535 \mathrm{sigmoid}\left(0.346 i_{0} + 0.044 i_{1} - 0.482 i_{2} + 0.875\right) - 0.264


In [5]:
# Seeded generator so the sample inputs (and the outputs of this and the next cell)
# are reproducible under Restart & Run All.
input_rng = np.random.default_rng(42)
random_inputs = input_rng.standard_normal(3)  # one value per network input (num_inputs=3)

# Evaluate the SymPy-derived forward function on the sample inputs.
sympy_outputs = forward_func(random_inputs)
sympy_outputs

[-0.7940936986556304]

In [6]:
# Evaluate the same genome with the JAX implementation on the same inputs.
transformed = genome.transform(state, nodes, conns)
jax_outputs = genome.forward(state, transformed, random_inputs)

# Cross-check against the SymPy forward function. The SymPy expressions were built with
# precision=3 (3-decimal coefficients) and JAX computes in float32, so only approximate
# agreement is expected (a previous run differed by roughly 6e-4).
np.testing.assert_allclose(
    np.asarray(jax_outputs),
    np.asarray(forward_func(random_inputs), dtype=np.float32),
    atol=1e-2,
)
jax_outputs

Array([-0.7934886], dtype=float32, weak_type=True)