In [1]:
import jax, jax.numpy as jnp
# Star imports: `DefaultGenome` (used in cell 6) is not imported by name anywhere,
# so it presumably comes from one of these modules — TODO confirm and import explicitly.
from algorithm.neat.genome import *
from algorithm.neat.gene import *

# Show arrays with 2 decimals and 150-char lines so each node/conn row prints on one line.
jnp.set_printoptions(precision=2, linewidth=150)

In [2]:
# genome = DefaultGenome(num_inputs=3, num_outputs=2, max_nodes=10, max_conns=10)
# state = genome.setup()
# randkey = jax.random.key(0)
# genome_key, input_key = jax.random.split(randkey)
# nodes, conns = genome.initialize(state, genome_key)
# inputs = jax.random.normal(input_key, (10, 3)) * 2 + 1  # std: 2, mean: 1
# print(nodes, conns, sep='\n')
# print(inputs)

In [3]:
# transformed = genome.transform(state, nodes, conns)
# batch_output = jax.vmap(genome.forward, in_axes=(None, 0, None))(state, inputs, transformed)
# batch_output, transformed

In [4]:
# batch_output2, new_transformed = genome.update_by_batch(state, inputs, transformed)
# batch_output2, new_transformed

In [5]:
# assert jnp.allclose(new_transformed[0], transformed[0], equal_nan=True)
# assert jnp.allclose(new_transformed[1], transformed[1], equal_nan=True)
# assert jnp.allclose(new_transformed[2], transformed[2], equal_nan=True)

In [6]:
# Build a small 3-input / 2-output genome whose nodes are NormalizedNode genes with the
# identity activation, then initialise it and compile it with `transform`.
from algorithm.neat.gene.node.normalized import NormalizedNode
from algorithm.neat.gene.conn import DefaultConnGene
from tensorneat.utils import ACT

genome = DefaultGenome(
    num_inputs=3,
    num_outputs=2,
    max_nodes=10,
    max_conns=10,
    node_gene=NormalizedNode(
        activation_default=ACT.identity,
        activation_options=(ACT.identity,),
    ),
    conn_gene=DefaultConnGene(weight_init_mean=1),
)
state = genome.setup()

randkey = jax.random.key(0)
genome_key, input_key = jax.random.split(randkey)

nodes, conns = genome.initialize(state, genome_key)
# Replace every attribute column with the genes' `new_custom_attrs` values, keeping
# column 0 of `nodes` (the node key) and columns 0-2 of `conns` (input key, output key
# and a third fixed field, as seen in the printout).
# NOTE(review): the attribute vector is broadcast over *all* rows, so the NaN-keyed
# padding rows also receive attribute values (visible in the printed tables) — confirm
# this is intended or mask with `~jnp.isnan(nodes[:, 0])` / `~jnp.isnan(conns[:, 0])`.
nodes = nodes.at[:, 1:].set(genome.node_gene.new_custom_attrs(state))
conns = conns.at[:, 3:].set(genome.conn_gene.new_custom_attrs(state))

# 10,000 input rows sampled with mean 1 and std 2.
inputs = 1 + 2 * jax.random.normal(input_key, (10000, 3))
print(nodes)
print(conns)
print(inputs)

transformed = genome.transform(state, nodes, conns)
transformed

[[ 0.  0.  0.  0.  0.  1.  1.  0.]
 [ 1.  0.  0.  0.  0.  1.  1.  0.]
 [ 2.  0.  0.  0.  0.  1.  1.  0.]
 [ 3.  0.  0.  0.  0.  1.  1.  0.]
 [ 4.  0.  0.  0.  0.  1.  1.  0.]
 [ 5.  0.  0.  0.  0.  1.  1.  0.]
 [nan  0.  0.  0.  0.  1.  1.  0.]
 [nan  0.  0.  0.  0.  1.  1.  0.]
 [nan  0.  0.  0.  0.  1.  1.  0.]
 [nan  0.  0.  0.  0.  1.  1.  0.]]
[[ 0.  5.  1.  1.]
 [ 1.  5.  1.  1.]
 [ 2.  5.  1.  1.]
 [ 5.  3.  1.  1.]
 [ 5.  4.  1.  1.]
 [nan nan nan  1.]
 [nan nan nan  1.]
 [nan nan nan  1.]
 [nan nan nan  1.]
 [nan nan nan  1.]]
[[-1.9  -3.53  0.94]
 [ 2.92  0.06  3.44]
 [-0.9  -0.06  2.94]
 ...
 [ 2.07 -1.43  1.55]
 [ 1.93  2.85  0.19]
 [ 0.91 -0.65  1.86]]


(Array([         0,          1,          2,          5,          3,          4, 2147483647, 2147483647, 2147483647, 2147483647],      dtype=int32, weak_type=True),
 Array([[ 0.,  0.,  0.,  0.,  0.,  1.,  1.,  0.],
        [ 1.,  0.,  0.,  0.,  0.,  1.,  1.,  0.],
        [ 2.,  0.,  0.,  0.,  0.,  1.,  1.,  0.],
        [ 3.,  0.,  0.,  0.,  0.,  1.,  1.,  0.],
        [ 4.,  0.,  0.,  0.,  0.,  1.,  1.,  0.],
        [ 5.,  0.,  0.,  0.,  0.,  1.,  1.,  0.],
        [nan,  0.,  0.,  0.,  0.,  1.,  1.,  0.],
        [nan,  0.,  0.,  0.,  0.,  1.,  1.,  0.],
        [nan,  0.,  0.,  0.,  0.,  1.,  1.,  0.],
        [nan,  0.,  0.,  0.,  0.,  1.,  1.,  0.]], dtype=float32, weak_type=True),
 Array([[[nan, nan, nan, nan, nan,  1., nan, nan, nan, nan],
         [nan, nan, nan, nan, nan,  1., nan, nan, nan, nan],
         [nan, nan, nan, nan, nan,  1., nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]

In [7]:
# Baseline: run the plain forward pass on every input row with the freshly built
# transform. `state` and the transform are shared (None); only `inputs` is mapped over axis 0.
batch_output2 = jax.vmap(
    genome.forward,
    in_axes=(None, 0, None),
)(state, inputs, transformed)
batch_output2

Array([[-4.49, -4.49],
       [ 6.42,  6.42],
       [ 1.98,  1.98],
       ...,
       [ 2.19,  2.19],
       [ 4.97,  4.97],
       [ 2.12,  2.12]], dtype=float32, weak_type=True)

In [8]:
# Forward the whole batch through `update_by_batch`, which returns the batch outputs and an
# updated transform (its node rows carry new statistics, e.g. node 5 -> 2.95 / 3.46).
# The `batch_z` / `batch_z_mean` lines in the captured output are printed from inside
# `update_by_batch`; this cell itself prints nothing.
(batch_output, new_transformed) = genome.update_by_batch(state, inputs, transformed)

batch_z: [-4.49  6.42  1.98 ...  2.19  4.97  2.12]
batch_z_mean: 2.9496588706970215
batch_z: [-2.15  1.   -0.28 ... -0.22  0.58 -0.24]
batch_z_mean: -2.1362303925798187e-08
batch_z: [-2.15  1.   -0.28 ... -0.22  0.58 -0.24]
batch_z_mean: -2.1362303925798187e-08


In [9]:
# Display the outputs from update_by_batch alongside the updated transform.
(batch_output, new_transformed)

(Array([[-2.15, -2.15],
        [ 1.  ,  1.  ],
        [-0.28, -0.28],
        ...,
        [-0.22, -0.22],
        [ 0.58,  0.58],
        [-0.24, -0.24]], dtype=float32, weak_type=True),
 (Array([         0,          1,          2,          5,          3,          4, 2147483647, 2147483647, 2147483647, 2147483647],      dtype=int32, weak_type=True),
  Array([[ 0.00e+00,  0.00e+00,  0.00e+00,  0.00e+00,  0.00e+00,  1.00e+00,  1.00e+00,  0.00e+00],
         [ 1.00e+00,  0.00e+00,  0.00e+00,  0.00e+00,  0.00e+00,  1.00e+00,  1.00e+00,  0.00e+00],
         [ 2.00e+00,  0.00e+00,  0.00e+00,  0.00e+00,  0.00e+00,  1.00e+00,  1.00e+00,  0.00e+00],
         [ 3.00e+00,  0.00e+00,  0.00e+00,  0.00e+00, -2.14e-08,  1.00e+00,  1.00e+00,  0.00e+00],
         [ 4.00e+00,  0.00e+00,  0.00e+00,  0.00e+00, -2.14e-08,  1.00e+00,  1.00e+00,  0.00e+00],
         [ 5.00e+00,  0.00e+00,  0.00e+00,  0.00e+00,  2.95e+00,  3.46e+00,  1.00e+00,  0.00e+00],
         [      nan,  0.00e+00,  0.00e+00,  0.00e+0

In [10]:
# Re-run the plain forward pass with the transform returned by update_by_batch.
# Using the stored statistics, it should reproduce update_by_batch's own outputs.
# Assert that directly instead of eyeballing two truncated printouts.
batch_output2 = jax.vmap(genome.forward, in_axes=(None, 0, None))(state, inputs, new_transformed)
assert batch_output2.shape == batch_output.shape, "output shapes differ"
# Small absolute tolerance: both paths are float32 and may normalise in a different op order.
assert jnp.allclose(batch_output2, batch_output, atol=1e-4), \
    "forward() with the updated transform disagrees with update_by_batch()"
batch_output2

Array([[-2.15, -2.15],
       [ 1.  ,  1.  ],
       [-0.28, -0.28],
       ...,
       [-0.22, -0.22],
       [ 0.58,  0.58],
       [-0.24, -0.24]], dtype=float32, weak_type=True)