In [1]:
import sys
import jax
import jax.numpy as jnp
from jax import random
from functools import partial
import chex

sys.path.append("../")

from neat_jax import (
    Network,
    ActivationState,
    make_network,
    plot_network,
    Mutations,
    update_depth,
)

In [2]:
topology_config_0 = {
    "max_nodes": 10,
    "senders": jnp.array([0, 1, 2, 4]),
    "receivers": jnp.array([4, 4, 3, 3]),
    "weights": jnp.array([1, 1, 1, 1]),
    "activation_indices": jnp.array([0, 0, 0, 0, 0]),
    "node_types": jnp.array([0, 0, 0, 2, 1]),
    "inputs": jnp.array([0.5, 0.8, 0.2]),
    "output_size": 1,
}
mutation_config_certain = {
    "weight_shift_rate": 1,
    "weight_mutation_rate": 1,
    "add_node_rate": 1,
    "add_connection_rate": 1,
}
mutation_config_null = {
    "weight_shift_rate": 0.0,
    "weight_mutation_rate": 0.0,
    "add_node_rate": 0.0,
    "add_connection_rate": 0.0,
}
mutation_config_0_5 = {
    "weight_shift_rate": 0.5,
    "weight_mutation_rate": 0.5,
    "add_node_rate": 0.5,
    "add_connection_rate": 0.5,
}

key = random.PRNGKey(0)
activation_state, net = make_network(**topology_config_0)
net

node_indices: [0 1 2 3 4 5 6 7 8 9]
node_types: [0 0 0 2 1 3 3 3 3 3]
activation_indices: [0 0 0 0 0 0 0 0 0 0]
weights: [1. 1. 1. 1. 0. 0. 0. 0. 0. 0.]
senders: [  0   1   2   4 -10 -10 -10 -10 -10 -10]
receivers: [  4   4   3   3 -10 -10 -10 -10 -10 -10]
output_size: 1




In [5]:
configs = [mutation_config_null, mutation_config_0_5, mutation_config_certain]
seeds = [0, 1, 0]
max_nodes = topology_config_0["max_nodes"]

for i in range(3):
    config = configs[i]
    key = random.PRNGKey(seeds[i])

    print("-" * 20)
    mutations = Mutations(max_nodes=max_nodes, **config)
    print("Weight shift weights:")
    print(mutations.weight_shift(key, net).weights)
    print("Weight mutation weights:")
    print(mutations.weight_mutation(key, net).weights)
    print("Add node network")
    print(mutations.add_node(key, net, 1))
    print("Add connection network")
    print(mutations.add_connection(key, net, activation_state, max_nodes))

--------------------
Weight shift weights:
[1. 1. 1. 1. 0. 0. 0. 0. 0. 0.]
Weight mutation weights:
[1. 1. 1. 1. 0. 0. 0. 0. 0. 0.]
Add node network
node_indices: [0 1 2 3 4 5 6 7 8 9]
node_types: [0 0 0 2 1 3 3 3 3 3]
activation_indices: [0 0 0 0 0 0 0 0 0 0]
weights: [1. 1. 1. 1. 0. 0. 0. 0. 0. 0.]
senders: [  0   1   2   4 -10 -10 -10 -10 -10 -10]
receivers: [  4   4   3   3 -10 -10 -10 -10 -10 -10]
output_size: 1

Add connection network
node_indices: [0 1 2 3 4 5 6 7 8 9]
node_types: [0 0 0 2 1 3 3 3 3 3]
activation_indices: [0 0 0 0 0 0 0 0 0 0]
weights: [1. 1. 1. 1. 0. 0. 0. 0. 0. 0.]
senders: [  0   1   2   4 -10 -10 -10 -10 -10 -10]
receivers: [  4   4   3   3 -10 -10 -10 -10 -10 -10]
output_size: 1

--------------------
Weight shift weights:
[0.9356227 1.        0.9701904 1.        0.        0.        0.
 0.        0.        0.       ]
Weight mutation weights:
[-0.06437729  1.         -0.02980961  1.          0.          0.
  0.          0.          0.          0.        ]
Add