In [3]:
from pipeline import Pipeline
from algorithm.neat import *
from algorithm.neat.gene.node.kan_node import KANNode
from algorithm.neat.gene.conn.bspline import BSplineConn
from problem.func_fit import XOR3d
from tensorneat.utils import ACT

import jax, jax.numpy as jnp

# Structural mutation settings: add/delete values for nodes and connections.
structural_mutation = DefaultMutation(
    conn_add=0.1,
    conn_delete=0.05,
    node_add=0.1,
    node_delete=0.05,
)

# Genome layout: 3 inputs and 1 output, capped at 5 nodes and 10 connections,
# using KANNode as the node gene and BSplineConn as the connection gene.
genome = DefaultGenome(
    num_inputs=3,
    num_outputs=1,
    max_nodes=5,
    max_conns=10,
    node_gene=KANNode(),
    conn_gene=BSplineConn(),
    mutation=structural_mutation,
    output_transform=ACT.sigmoid,  # activation applied at the output node
)

# setup() produces the State object that the later genome calls
# (initialize / transform / forward) receive as their first argument.
state = genome.setup()
state

State ({'kan_initial_grids': Array([-1. , -0.5,  0. ,  0.5,  1. ], dtype=float32)})

In [4]:
# PRNG key built from the fixed seed 0; it is handed to genome.initialize below.
randkey = jax.random.key(0)
# initialize() returns a pair that is unpacked here into the genome's two arrays.
# NOTE(review): the recorded output shows NaN-filled rows in `conns` — presumably
# padding for unused connection slots up to max_conns; TODO confirm.
nodes, conns = genome.initialize(state, randkey)
# Bare trailing expression so the notebook displays both arrays.
nodes, conns

(Array([[0.],
        [1.],
        [2.],
        [3.],
        [4.]], dtype=float32, weak_type=True),
 Array([[ 0.        ,  4.        , -1.        , -0.5       ,  0.        ,
          0.5       ,  1.        ,  0.04929435, -1.2567043 ,  1.1369427 ,
          0.6141437 ,  1.4434636 ,  0.24439397,  0.77281904],
        [ 1.        ,  4.        , -1.        , -0.5       ,  0.        ,
          0.5       ,  1.        ,  0.90565056,  1.4197341 ,  0.82603943,
          1.164936  , -0.74349356,  0.9511131 , -1.5443964 ],
        [ 2.        ,  4.        , -1.        , -0.5       ,  0.        ,
          0.5       ,  1.        ,  1.7152852 , -1.6385511 ,  1.0964565 ,
          0.6741095 ,  1.4752939 , -0.3695403 , -0.5071054 ],
        [ 4.        ,  3.        , -1.        , -0.5       ,  0.        ,
          0.5       ,  1.        , -1.2653785 , -1.2907758 ,  0.6196416 ,
         -0.8124694 , -0.7498491 , -1.582707  , -0.04516089],
        [        nan,         nan,         nan,         n

In [5]:
# transform() turns the (nodes, conns) genome into the structure that
# genome.forward() is later called with.
# NOTE(review): the recorded output is a tuple whose first element is an int32
# array of node indices — looks like an evaluation order; TODO confirm.
transformed = genome.transform(state, nodes, conns)
# Bare trailing expression so the notebook displays the transformed genome.
transformed

(Array([0, 1, 2, 4, 3], dtype=int32, weak_type=True),
 Array([[0.],
        [1.],
        [2.],
        [3.],
        [4.]], dtype=float32, weak_type=True),
 Array([[[        nan,         nan,         nan,         nan,
          -1.        ],
         [        nan,         nan,         nan,         nan,
          -1.        ],
         [        nan,         nan,         nan,         nan,
          -1.        ],
         [        nan,         nan,         nan,         nan,
                  nan],
         [        nan,         nan,         nan, -1.        ,
                  nan]],
 
        [[        nan,         nan,         nan,         nan,
          -0.5       ],
         [        nan,         nan,         nan,         nan,
          -0.5       ],
         [        nan,         nan,         nan,         nan,
          -0.5       ],
         [        nan,         nan,         nan,         nan,
                  nan],
         [        nan,         nan,         nan, -0.5       ,
    

In [8]:
# Probe the network once with an all-ones input vector of length 3.
all_ones = jnp.array([1, 1, 1])
res = genome.forward(state, all_ones, transformed)
# NOTE(review): the recorded output of this cell is Array([nan]); the source of
# the NaN has not been identified from this notebook — TODO investigate.
res

Array([nan], dtype=float32, weak_type=True)