In [1]:
import hydra
import jax
import jax.numpy as jnp
from omegaconf import DictConfig, OmegaConf
from hydra import initialize, initialize_config_module, initialize_config_dir, compose
from neat_jax import Mutations, forward, init_network, log_config, Network, sample_from_mask
import chex


# Sizes written into the config (`config.input_size` / `config.output_size`)
# before the network is built.
INPUT_SIZE: int = 3   # length of the input vector fed to the network
OUTPUT_SIZE: int = 2  # output size passed to `init_network`

In [2]:
# Build the experiment configuration with Hydra's Compose API instead of the
# `@hydra.main` decorator, which is not usable inside a notebook cell.
# NOTE(review): Hydra resolves `config_path` relative to the calling file, so
# this presumably requires the notebook to sit next to `neat_jax/` — confirm.
with initialize(
    config_path="neat_jax/configs",
    version_base="1.3.2",
):
    config = compose(config_name="default_config")

In [10]:
# Build a network, apply a single `add_node` mutation and run one forward pass.

# Turn struct mode off so the config can be modified freely below.
# NOTE(review): `input_size` / `output_size` are presumably absent from the
# default config, which is why struct mode has to be disabled — confirm.
OmegaConf.set_struct(config, False)

# Force both mutation rates to 1.0 for this experiment.
# NOTE(review): presumably this makes `add_node` fire deterministically — confirm.
config.mutations.activation_fn_mutation_rate = 1.0
config.mutations.add_node_rate = 1.0

config.input_size = INPUT_SIZE
config.output_size = OUTPUT_SIZE

# Log only after every override has been applied, so the printed
# hyperparameters match what is actually used, even on a fresh kernel.
log_config(config)

# JAX PRNG keys must not be reused: derive one independent subkey per random
# consumer. `key` itself stays bound to the root key for any later cell.
key = jax.random.key(config.params.seed)
input_key, init_key, mutation_key = jax.random.split(key, 3)

inputs = jax.random.normal(input_key, (config.input_size,))

net, activation_state = init_network(
    inputs, config.output_size, config.network.max_nodes, init_key
)
mutations = Mutations(max_nodes=config.network.max_nodes, **config.mutations)
net = mutations.add_node(mutation_key, net)
activation_state, y = forward(inputs, net, config)

print(net, y)

[34m[1mRunning Neat experiment:
[32m[1mHyperparameters:[22m{
    "env": {
        "kwargs": {},
        "scenario": {
            "name": "CartPole-v1",
            "task_name": "cartpole"
        },
        "suit_name": "gymnax"
    },
    "input_size": 3,
    "mutations": {
        "activation_fn_mutation_rate": 1.0,
        "add_connection_rate": 0.05,
        "add_node_rate": 1.0,
        "weight_mutation_rate": 0.1,
        "weight_shift_rate": 0.5
    },
    "network": {
        "max_connections": 45,
        "max_nodes": 30
    },
    "output_size": 2,
    "params": {
        "population_size": 1024,
        "seed": 0,
        "total_generations": 10000000.0
    }
}[0m
input_size: 3
output_size: 2
max_nodes: 30
node_indices: [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29]
node_types: [0 0 0 2 2 1 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3]
activation_fns: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
senders