# Tutorial 1: Genome
The genome is the core component of TensorNEAT. It represents a network's genotype: the node and connection genes that define the network's structure and parameters.


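Before working with genomes, we create a genome object (here a `DefaultGenome`), which defines how genomes are structured and how they behave: network size limits, node and connection genes, and mutation settings. After creating it, call `setup()` to initialize its state.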

```python
from tensorneat.genome import DefaultGenome, BiasNode, DefaultConn, DefaultMutation

genome = DefaultGenome(
    num_inputs=3,  # 3 inputs
    num_outputs=1,  # 1 output
    max_nodes=5,  # the network will have at most 5 nodes
    max_conns=10,  # the network will have at most 10 connections
    node_gene=BiasNode(),  # node with 3 attributes: bias, aggregation, activation
    conn_gene=DefaultConn(),  # connection with 1 attribute: weight
    mutation=DefaultMutation(
        node_add=0.9,
        node_delete=0.0,
        conn_add=0.9,
        conn_delete=0.0
    )  # high mutation rate for testing
)

state = genome.setup()
```
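Why must `max_nodes` and `max_conns` be fixed up front? JAX compiles functions for fixed array shapes, so a library built on it most likely stores every network in arrays of the same size and marks unused slots. The sketch below illustrates that idea with a hypothetical layout (NaN-padded rows with made-up columns). It is not TensorNEAT's actual internal format, and `make_padded_nodes` and `count_valid` are illustrative helpers.

```python
import jax.numpy as jnp

# Hypothetical padded layout, for illustration only: each row is one node
# with columns (key, bias, aggregation_id, activation_id). Unused rows are
# filled with NaN, so the array shape stays the same whatever the topology.
MAX_NODES = 5
NUM_ATTRS = 4

def make_padded_nodes(rows):
    """Place `rows` into a (MAX_NODES, NUM_ATTRS) array padded with NaN."""
    padded = jnp.full((MAX_NODES, NUM_ATTRS), jnp.nan)
    rows = jnp.asarray(rows, dtype=jnp.float32)
    return padded.at[: rows.shape[0]].set(rows)

def count_valid(nodes):
    """Count rows whose first column is not NaN (i.e. real nodes)."""
    return jnp.sum(~jnp.isnan(nodes[:, 0]))

# 3 input nodes and 1 output node occupy 4 of the 5 slots.
nodes = make_padded_nodes([
    [0, 0.0, 0, 0],
    [1, 0.0, 0, 0],
    [2, 0.0, 0, 0],
    [3, 0.5, 0, 1],
])
print(nodes.shape)               # (5, 4)
print(int(count_valid(nodes)))   # 4
```

Adding a node to such a network would fill the NaN row rather than change the array's shape.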

After creating the genome, we can use it for network operations such as random initialization, forward passes, mutation, and distance calculation. All of these operations can be JIT-compiled and vectorized with `jax.vmap`.
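To see what JIT compilation buys us, the toy example below (plain JAX, with a hypothetical `toy_forward` standing in for a network) counts how often `jax.jit` traces a function. Calls with the same input shapes and dtypes reuse the compiled code, and only a new shape triggers a retrace. If every network is stored in arrays of the same size, networks of different topologies can therefore share one compiled function.

```python
import jax
import jax.numpy as jnp

trace_count = 0

def toy_forward(w, x):
    """Hypothetical stand-in for a network forward pass."""
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces the function
    return jnp.tanh(x @ w)

fast_forward = jax.jit(toy_forward)
w = jnp.ones((3, 1))

fast_forward(w, jnp.ones((3,)))     # first call: traced and compiled
fast_forward(w, jnp.zeros((3,)))    # same shapes and dtypes: compiled code reused
fast_forward(w, jnp.ones((4, 3)))   # new input shape: traced again
print(trace_count)  # 2
```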

```python
import jax, jax.numpy as jnp

# Initialize a network
nodes, conns = genome.initialize(state, jax.random.PRNGKey(0))

# Network forward
single_input = jax.random.normal(jax.random.PRNGKey(1), (3, ))
transformed = genome.transform(state, nodes, conns)
output = genome.forward(state, transformed, single_input)
print(f"{output=}")

# Network batch forward
batch_inputs = jax.random.normal(jax.random.PRNGKey(2), (10, 3))
batch_output = jax.vmap(genome.forward, in_axes=(None, None, 0))(state, transformed, batch_inputs)
print(f"{batch_output=}")
```

```
output=Array([-1.5388116], dtype=float32, weak_type=True)
batch_output=Array([[ 2.5477247 ],
       [ 2.2334106 ],
       [ 1.8713341 ],
       [-3.7539673 ],
       [ 1.5344429 ],
       [ 2.7640016 ],
       [ 0.5649997 ],
       [-0.32709932],
       [ 3.5273829 ],
       [-0.64774114]], dtype=float32, weak_type=True)
```
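In the batch forward above, `in_axes=(None, None, 0)` tells `jax.vmap` to share `state` and `transformed` across the batch and to map over the first axis of `batch_inputs`. The result should match a plain Python loop over the inputs, as this toy check illustrates (`toy_forward` is a hypothetical stand-in for `genome.forward`, with `scale` and `w` playing the roles of the two shared arguments).

```python
import jax
import jax.numpy as jnp

def toy_forward(scale, w, x):
    """Hypothetical stand-in: the first two arguments are shared, x is one input."""
    return scale * jnp.tanh(x @ w)

scale = 2.0
w = jax.random.normal(jax.random.PRNGKey(0), (3, 1))
batch = jax.random.normal(jax.random.PRNGKey(2), (10, 3))

# Same in_axes pattern as the tutorial: share the first two arguments,
# map over axis 0 of the third.
vmapped = jax.vmap(toy_forward, in_axes=(None, None, 0))(scale, w, batch)

# Equivalent, but slower, Python loop over the rows of `batch`.
looped = jnp.stack([toy_forward(scale, w, x) for x in batch])

print(vmapped.shape)                                   # (10, 1)
print(bool(jnp.allclose(vmapped, looped, atol=1e-5)))  # True
```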


```python
# Initialize a population of networks (1000 networks)
pop_nodes, pop_conns = jax.vmap(genome.initialize, in_axes=(None, 0))(
    state, jax.random.split(jax.random.PRNGKey(0), 1000)
)

# Population forward
pop_inputs = jax.random.normal(jax.random.PRNGKey(1), (1000, 3))
pop_transformed = jax.vmap(genome.transform, in_axes=(None, 0, 0))(
    state, pop_nodes, pop_conns
)
pop_outputs = jax.vmap(genome.forward, in_axes=(None, 0, 0))(
    state, pop_transformed, pop_inputs
)

print(f"{pop_outputs.shape=}")
```

```
pop_outputs.shape=(1000, 1)
```
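For the population, `in_axes=(None, 0, 0)` maps over both the networks and the inputs, so network `i` receives input `i`. The same pattern is sketched below in plain JAX: a hypothetical `toy_init` creates one weight matrix per random key, mirroring how `jax.random.split` produces one key per network above.

```python
import jax
import jax.numpy as jnp

POP_SIZE = 1000

def toy_init(key):
    """Hypothetical stand-in for genome.initialize: one random weight matrix per key."""
    return jax.random.normal(key, (3, 1))

def toy_forward(w, x):
    """Hypothetical stand-in for genome.forward."""
    return jnp.tanh(x @ w)

keys = jax.random.split(jax.random.PRNGKey(0), POP_SIZE)  # one key per network
pop_w = jax.vmap(toy_init)(keys)                          # shape (1000, 3, 1)
pop_inputs = jax.random.normal(jax.random.PRNGKey(1), (POP_SIZE, 3))

# in_axes=(0, 0): network i receives input i. This matches in_axes=(None, 0, 0)
# in the tutorial once the shared `state` argument is set aside.
pop_out = jax.vmap(toy_forward, in_axes=(0, 0))(pop_w, pop_inputs)
print(pop_w.shape, pop_out.shape)  # (1000, 3, 1) (1000, 1)
```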


```python
# Visualize the network
network = genome.network_dict(state, nodes, conns)  # Transform the network from JAX arrays to a Python dict
genome.visualize(network, save_path="./origin_network.svg")
```

![origin_network](./origin_network.svg)
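`network_dict` turns JAX arrays into ordinary Python data that a plotting routine can walk. The exact structure of TensorNEAT's dict is not shown in this tutorial, so the helper below (`padded_to_dict`) is purely hypothetical: it assumes NaN-padded node rows `(key, bias)` and connection rows `(in_key, out_key, weight)`, copies the arrays to host NumPy, drops the padding, and builds a dict.

```python
import jax.numpy as jnp
import numpy as np

def padded_to_dict(nodes, conns):
    """Hypothetical helper: drop NaN padding and return plain Python objects."""
    nodes_np = np.asarray(nodes)  # copy device arrays to host NumPy
    conns_np = np.asarray(conns)
    node_rows = nodes_np[~np.isnan(nodes_np[:, 0])]
    conn_rows = conns_np[~np.isnan(conns_np[:, 0])]
    return {
        "nodes": {int(r[0]): {"bias": float(r[1])} for r in node_rows},
        "conns": {(int(r[0]), int(r[1])): {"weight": float(r[2])} for r in conn_rows},
    }

nan = jnp.nan
nodes = jnp.array([[0, 0.0], [1, 0.0], [2, 0.3], [nan, nan]])
conns = jnp.array([[0, 2, 0.5], [1, 2, -1.0], [nan, nan, nan]])
network = padded_to_dict(nodes, conns)
print(network["conns"])  # {(0, 2): {'weight': 0.5}, (1, 2): {'weight': -1.0}}
```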

```python
# Mutate the network
mutated_nodes, mutated_conns = genome.execute_mutation(
    state,
    jax.random.PRNGKey(2),
    nodes,
    conns,
    new_node_key=jnp.asarray(5),  # at most 1 node can be added in each mutation
    new_conn_keys=jnp.asarray([6, 7, 8]),  # at most 3 connections can be added
)

# Visualize the mutated network
mutated_network = genome.network_dict(state, mutated_nodes, mutated_conns)
genome.visualize(mutated_network, save_path="./mutated_network.svg")
```

![mutated_network](./mutated_network.svg)
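The caps on additions per mutation (one new node, three new connections) are consistent with fixed-capacity arrays: under JIT, a mutation can fill empty slots but cannot grow the arrays. The sketch below shows that idea with a hypothetical `add_conn` that writes a connection into the first NaN-padded row of a connection array. The row layout `(in_key, out_key, weight)` is an assumption, not TensorNEAT's implementation.

```python
import jax
import jax.numpy as jnp

@jax.jit
def add_conn(conns, in_key, out_key, weight):
    """Hypothetical sketch: write a new connection into the first NaN row.

    The array shape never changes, so this stays JIT-compatible. If no free
    row exists, the array is returned unchanged.
    """
    free = jnp.isnan(conns[:, 0])
    has_free = jnp.any(free)
    idx = jnp.argmax(free)  # index of the first True (0 if none are True)
    new_row = jnp.array([in_key, out_key, weight], dtype=conns.dtype)
    updated = conns.at[idx].set(new_row)
    return jnp.where(has_free, updated, conns)

nan = jnp.nan
conns = jnp.array([[0, 3, 0.5], [nan, nan, nan], [nan, nan, nan]])
conns = add_conn(conns, 1, 3, -0.25)
print(conns[1])  # row 1 now holds in_key=1, out_key=3, weight=-0.25
```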