In [1]:
from algorithm.neat.genome import DefaultGenome
from utils.tools import flatten_conns, unflatten_conns
import jax, jax.numpy as jnp
from jax import vmap

In [2]:
# Build a deliberately tiny genome so every array below is small enough to read.
NUM_INPUTS, NUM_OUTPUTS = 3, 1
MAX_NODES, MAX_CONNS = 5, 5
SEED = 0

genome = DefaultGenome(
    num_inputs=NUM_INPUTS,
    num_outputs=NUM_OUTPUTS,
    max_nodes=MAX_NODES,
    max_conns=MAX_CONNS,
)
state = genome.setup()
key = jax.random.PRNGKey(SEED)
nodes, conns = genome.initialize(state, key)
(nodes.shape, conns.shape)

((5, 5), (5, 4))

In [3]:
# Convert the flat (nodes, conns) arrays into the genome's transformed
# representation; the tuple is displayed as the cell output.
transformed = genome.transform(state, nodes, conns)
transformed

(Array([0, 1, 2, 4, 3], dtype=int32, weak_type=True),
 Array([[ 0.        , -1.013169  ,  1.        ,  0.        ,  0.        ],
        [ 1.        , -0.3775248 ,  1.        ,  0.        ,  0.        ],
        [ 2.        ,  0.7407059 ,  1.        ,  0.        ,  0.        ],
        [ 3.        , -0.66817343,  1.        ,  0.        ,  0.        ],
        [ 4.        ,  0.5336131 ,  1.        ,  0.        ,  0.        ]],      dtype=float32, weak_type=True),
 Array([[[        nan,         nan,         nan,         nan,
           0.13149254],
         [        nan,         nan,         nan,         nan,
           0.02001922],
         [        nan,         nan,         nan,         nan,
          -0.79229796],
         [        nan,         nan,         nan,         nan,
                  nan],
         [        nan,         nan,         nan, -0.57102853,
                  nan]]], dtype=float32, weak_type=True))

In [4]:
# Single-genome case: `genome.restore` converts `transformed` back into
# flat (nodes, conns) arrays.
nodes, conns = genome.restore(state, transformed)
(nodes, conns)

(Array([[ 0.        , -1.013169  ,  1.        ,  0.        ,  0.        ],
        [ 1.        , -0.3775248 ,  1.        ,  0.        ,  0.        ],
        [ 2.        ,  0.7407059 ,  1.        ,  0.        ,  0.        ],
        [ 3.        , -0.66817343,  1.        ,  0.        ,  0.        ],
        [ 4.        ,  0.5336131 ,  1.        ,  0.        ,  0.        ]],      dtype=float32, weak_type=True),
 Array([[ 1.        ,  0.        ,  4.        ,  0.13149254],
        [ 1.        ,  1.        ,  4.        ,  0.02001922],
        [ 1.        ,  2.        ,  4.        , -0.79229796],
        [ 1.        ,  4.        ,  3.        , -0.57102853],
        [ 1.        ,         nan,         nan,         nan]],      dtype=float32))

In [7]:
# Insert a column filled with 1 at column index 3 of the restored `conns`.
# The result goes to a NEW name on purpose: the previous form
# `conns = jnp.insert(conns, ...)` overwrote its own input, so every re-run
# widened `conns` by another column and left it unusable for earlier cells
# (e.g. re-running `genome.transform`). This version is idempotent.
conns_with_extra_col = jnp.insert(conns, obj=3, values=1, axis=1)
conns_with_extra_col

Array([[ 1.        ,  3.        ,  0.        ,  1.        ,  4.        ,
         0.13149254],
       [ 1.        ,  3.        ,  1.        ,  1.        ,  4.        ,
         0.02001922],
       [ 1.        ,  3.        ,  2.        ,  1.        ,  4.        ,
        -0.79229796],
       [ 1.        ,  3.        ,  4.        ,  1.        ,  3.        ,
        -0.57102853],
       [ 1.        ,  3.        ,         nan,  1.        ,         nan,
                nan]], dtype=float32)

In [8]:
# Batched case: initialize a small population by vmapping `genome.initialize`
# over independent PRNG keys; `state` is shared (in_axes=None), keys are batched.
POP_SIZE = 3
key = jax.random.PRNGKey(0)
keys = jax.random.split(key, POP_SIZE)
batched_initialize = vmap(genome.initialize, in_axes=(None, 0))
pop_nodes, pop_conns = batched_initialize(state, keys)
(pop_nodes.shape, pop_conns.shape)

((3, 10, 5), (3, 10, 4))

In [9]:
# Expand every genome's flat connection array into its dense form, mapping
# over the population axis of both arguments.
# NOTE(review): dense layout assumed to be (attrs, N, N) per genome — confirm
# against unflatten_conns.
pop_unflatten = vmap(unflatten_conns, in_axes=(0, 0))(pop_nodes, pop_conns)
pop_unflatten.shape

(3, 2, 10, 10)

In [10]:
# Batched inverse of the previous cell: turn the dense per-genome tensors back
# into flat connection arrays. The third argument is not batched
# (in_axes=None); it is presumably the number of connection rows to emit
# (max_conns). Derive it from `pop_conns` instead of hard-coding 10, so the
# cell stays correct for whatever max_conns the genome was built with.
num_conn_rows = pop_conns.shape[1]
flatten = jax.vmap(flatten_conns, in_axes=(0, 0, None))(pop_nodes, pop_unflatten, num_conn_rows)

# Round-trip sanity check: re-flattening should reproduce the original shape.
assert flatten.shape == pop_conns.shape, (flatten.shape, pop_conns.shape)
flatten.shape

(3, 10, 4)