# Tutorial 2: Usage

This tutorial shows how to use TensorNEAT to solve problems.  

TensorNEAT provides a **pipeline**, allowing users to run the NEAT algorithm efficiently after setting up the required components (problem, algorithm, and pipeline).  
Once everything is ready, users can call `pipeline.auto_run()` to execute the NEAT algorithm.  
The `auto_run()` method maximizes performance by parallelizing the execution using `jax.vmap` and compiling operations with `jax.jit`, making full use of GPU acceleration.
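
As a rough illustration of the `jax.vmap` plus `jax.jit` pattern mentioned above, the sketch below scores a whole population in one compiled call. It is not TensorNEAT's internal code: `forward`, `fitness`, and the single-unit "network" are hypothetical stand-ins chosen to keep the example small.

```python
import jax
import jax.numpy as jnp

def forward(params, x):
    # Hypothetical stand-in for one individual's network: a single tanh unit.
    w, b = params
    return jnp.tanh(jnp.dot(w, x) + b)

def fitness(params, inputs, labels):
    # Negative mean squared error of one individual over the whole dataset.
    preds = jax.vmap(forward, in_axes=(None, 0))(params, inputs)
    return -jnp.mean(jnp.square(preds - labels))

# Vectorize over the population axis of the parameters, then compile once.
batched_fitness = jax.jit(jax.vmap(fitness, in_axes=(0, None, None)))

key_w, key_b, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
pop_size, n_inputs, n_samples = 8, 2, 16
pop_params = (
    jax.random.normal(key_w, (pop_size, n_inputs)),  # one weight vector per individual
    jax.random.normal(key_b, (pop_size,)),           # one bias per individual
)
inputs = jax.random.uniform(key_x, (n_samples, n_inputs))
labels = jnp.sum(inputs, axis=1)

pop_fitness = batched_fitness(pop_params, inputs, labels)
print(pop_fitness.shape)  # (8,)
```

The inner `vmap` batches over samples and the outer `vmap` batches over individuals, so no Python loop runs per individual or per sample.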

---

## Types of Problems in TensorNEAT  

The problems to be solved using TensorNEAT can be categorized into the following cases:

1. **Problems already provided by TensorNEAT** (Function Fit, Gymnax, Brax)  
   - In this case, users can directly create a pipeline and execute it.

2. **Problems not provided by TensorNEAT but JIT-compatible** (supporting `jax.jit`)  
   - Users need to create a **Custom Problem class**, then create a pipeline for execution.

3. **Problems not provided by TensorNEAT and not JIT-compatible**  
   - In this case, users **cannot** create a pipeline for direct execution. Instead, the NEAT algorithm must be manually executed.  
   - The detailed method for manual execution is explained below.
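
To decide between cases 2 and 3, it can help to test whether an evaluation function survives JAX tracing. The helper below is a heuristic sketch, not part of TensorNEAT: `is_jit_compatible` and the two scoring functions are hypothetical names, and a function that traces successfully may still be unsuitable if it relies on side effects.

```python
import jax
import jax.numpy as jnp

def is_jit_compatible(fn, *example_args):
    """Return True if `fn` can be traced by JAX on the example arguments."""
    try:
        jax.make_jaxpr(fn)(*example_args)
        return True
    except Exception:
        # Tracing fails, for instance, when Python control flow
        # depends on the value of a traced array.
        return False

def jitable_score(x):
    # Branching expressed with jnp.where stays inside the traced computation.
    return jnp.where(x.sum() > 0, 1.0, -1.0)

def unjitable_score(x):
    # A Python `if` on an array value needs a concrete bool during tracing.
    if x.sum() > 0:
        return 1.0
    return -1.0

example = jnp.ones(3)
print(is_jit_compatible(jitable_score, example))    # True
print(is_jit_compatible(unjitable_score, example))  # False
```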

## 2.1 Using Existing Benchmarks

TensorNEAT currently provides benchmarks for **Function Fit (Symbolic Regression)** and **Reinforcement Learning (RL) tasks** using **Gymnax** and **Brax**.  

If you want to use these predefined problems, refer to the **examples** for implementation details.

## 2.2 Custom Jitable Problem
The following code demonstrates how users can define a custom problem and create a pipeline for automatic execution:

In [20]:
# Preparation
import jax, jax.numpy as jnp
from tensorneat.problem import BaseProblem

# The problem is to fit pagie_polynomial
def pagie_polynomial(inputs):
    x, y = inputs
    res = 1 / (1 + jnp.pow(x, -4)) + 1 / (1 + jnp.pow(y, -4))

    # Important! Returns an array with one item, NOT a scalar
    return jnp.array([res])

# Create dataset (10 samples)
INPUTS = jax.random.uniform(jax.random.PRNGKey(0), (10, 2))
LABELS = jax.vmap(pagie_polynomial)(INPUTS)

print(f"{INPUTS.shape=}, {LABELS.shape=}")

INPUTS.shape=(10, 2), LABELS.shape=(10, 1)
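
The comment in `pagie_polynomial` insists on returning a one-element array rather than a scalar. One likely reason, assuming the network output has shape `(1,)` per sample, is array broadcasting: subtracting labels of shape `(10,)` from predictions of shape `(10, 1)` silently produces a `(10, 10)` matrix, so the loss would average over every prediction-label pair. A NumPy sketch with made-up values:

```python
import numpy as np

rng = np.random.default_rng(0)
predict = rng.normal(size=(10, 1))  # network outputs: one value per sample
target = rng.normal(size=10)        # labels produced by a function returning a scalar

# Labels reshaped to (10, 1): element-wise difference, as intended.
diff_ok = predict - target[:, None]
# Labels left as (10,): broadcasting pairs every prediction with every label.
diff_bad = predict - target

print(diff_ok.shape)   # (10, 1)
print(diff_bad.shape)  # (10, 10)
```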


In [21]:
# Define the custom Problem
class CustomProblem(BaseProblem):

    jitable = True # necessary

    def evaluate(self, state, randkey, act_func, params):
        # Use ``act_func(state, params, inputs)`` to do network forward

        # do batch forward for all inputs (using jax.vmap)
        predict = jax.vmap(act_func, in_axes=(None, None, 0))(
            state, params, INPUTS
        )  # shape (10, 1): one prediction per sample in INPUTS

        # calculate loss
        loss = jnp.mean(jnp.square(predict - LABELS))

        # return negative loss as fitness
        # TensorNEAT maximizes fitness, which is equivalent to minimizing loss
        return -loss

    @property
    def input_shape(self):
        # the input shape that the act_func expects
        return (2, )
    
    @property
    def output_shape(self):
        # the output shape that the act_func returns
        return (1, )
    
    def show(self, state, randkey, act_func, params, *args, **kwargs):
        # showcase the performance of one individual
        predict = jax.vmap(act_func, in_axes=(None, None, 0))(
            state, params, INPUTS
        )

        loss = jnp.mean(jnp.square(predict - LABELS))

        msg = ""
        for i in range(INPUTS.shape[0]):
            msg += f"input: {INPUTS[i]}, target: {LABELS[i]}, predict: {predict[i]}\n"
        msg += f"loss: {loss}\n"
        print(msg)

In [22]:
import jax.numpy as jnp

from tensorneat.pipeline import Pipeline
from tensorneat.algorithm.neat import NEAT
from tensorneat.genome import DefaultGenome, BiasNode
from tensorneat.problem.func_fit import CustomFuncFit
from tensorneat.common import ACT, AGG

# Construct the pipeline and run
pipeline = Pipeline(
    algorithm=NEAT(
        pop_size=1000,
        species_size=20,
        survival_threshold=0.01,
        genome=DefaultGenome(
            num_inputs=2,
            num_outputs=1,
            init_hidden_layers=(),
            node_gene=BiasNode(
                activation_options=[ACT.identity, ACT.inv],
                aggregation_options=[AGG.sum, AGG.product],
            ),
            output_transform=ACT.identity,
        ),
    ),
    problem=CustomProblem(),
    generation_limit=5,
    fitness_target=-1e-4,
    seed=42,
)

# initialize state
state = pipeline.setup()
# run until termination
state, best = pipeline.auto_run(state)
# show result
pipeline.show(state, best)

initializing
initializing finished
start compile
compile finished, cost time: 10.663079s
Generation: 1, Cost time: 64.31ms
 	fitness: valid cnt: 1000, max: -0.0187, min: -15.4957, mean: -1.5639, std: 2.0848

	node counts: max: 4, min: 3, mean: 3.10
 	conn counts: max: 3, min: 0, mean: 1.88
 	species: 20, [593, 11, 15, 6, 27, 25, 47, 41, 14, 27, 22, 4, 12, 35, 28, 41, 16, 17, 6, 13]

Generation: 2, Cost time: 65.52ms
 	fitness: valid cnt: 999, max: -0.0104, min: -120.0426, mean: -0.5639, std: 4.4452

	node counts: max: 5, min: 3, mean: 3.17
 	conn counts: max: 4, min: 0, mean: 1.87
 	species: 20, [112, 139, 194, 52, 1, 71, 53, 95, 39, 25, 14, 2, 10, 35, 37, 7, 1, 87, 9, 17]

Generation: 3, Cost time: 59.10ms
 	fitness: valid cnt: 975, max: -0.0057, min: -57.8308, mean: -0.1830, std: 1.8740

	node counts: max: 6, min: 3, mean: 3.49
 	conn counts: max: 6, min: 0, mean: 2.47
 	species: 20, [35, 126, 43, 114, 1, 73, 9, 65, 321, 17, 51, 5, 35, 24, 14, 20, 1, 6, 37, 3]

Generation: 4, Cost ti

## 2.3 Custom Un-Jitable Problem
This scenario is more complex because we cannot directly construct a pipeline to run the NEAT algorithm. The following code demonstrates how to use TensorNEAT to execute an un-jitable custom problem.

In [23]:
# We use CartPole from gymnasium as the un-jitable problem
import gymnasium as gym
env = gym.make("CartPole-v1")

from tensorneat.common import State
# Define the genome and pre-jit the genome functions we need
genome = DefaultGenome(
    num_inputs=4,
    num_outputs=2,
    init_hidden_layers=(),
    node_gene=BiasNode(),
    output_transform=jnp.argmax,
)
state = State(randkey=jax.random.key(0))
state = genome.setup(state)

jit_transform = jax.jit(genome.transform)
jit_forward = jax.jit(genome.forward)

In [24]:
# define the method to evaluate the individual and the population
from tqdm import tqdm

def evaluate(state, nodes, conns):
    # evaluate the individual
    transformed = jit_transform(state, nodes, conns)

    observation, info = env.reset()
    episode_over, total_reward = False, 0
    while not episode_over:
        action = jit_forward(state, transformed, observation)
        # the action is currently a JAX array on the device (e.g. GPU)
        # move it to the host (CPU) before stepping the env
        action = jax.device_get(action)

        observation, reward, terminated, truncated, info = env.step(action)
        total_reward += reward
        episode_over = terminated or truncated

    return total_reward

def evaluate_population(state, pop_nodes, pop_conns):
    # evaluate the population
    pop_size = pop_nodes.shape[0]
    fitness = []
    for i in tqdm(range(pop_size)):
        fitness.append(
            evaluate(state, pop_nodes[i], pop_conns[i])
        )

    # return a jax array
    return jnp.asarray(fitness)

In [25]:
# define the algorithm
algorithm = NEAT(
    pop_size=100,
    species_size=20,
    survival_threshold=0.1,
    genome=genome,
)
state = algorithm.setup(state)

# jit for acceleration
jit_algorithm_ask = jax.jit(algorithm.ask)
jit_algorithm_tell = jax.jit(algorithm.tell)
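
The `ask` and `tell` functions follow the general ask-tell pattern: the algorithm proposes candidates, external code scores them, and the scores are fed back. The toy below sketches that control flow in plain NumPy. `ToyAskTell` is a hypothetical class, not TensorNEAT's `NEAT`, and it keeps its state on the object, whereas the TensorNEAT calls in this tutorial pass `state` in and out explicitly.

```python
import numpy as np

class ToyAskTell:
    """Hypothetical minimal optimizer exposing ask/tell (not TensorNEAT's NEAT)."""

    def __init__(self, pop_size, dim, seed=0):
        self.rng = np.random.default_rng(seed)
        self.mean = np.zeros(dim)
        self.pop_size = pop_size
        self.pop = None

    def ask(self):
        # Propose candidates around the current mean.
        noise = self.rng.normal(size=(self.pop_size, self.mean.size))
        self.pop = self.mean + 0.5 * noise
        return self.pop

    def tell(self, fitness):
        # Move the mean to the average of the best quarter of the candidates.
        n_elite = max(1, self.pop_size // 4)
        elite = np.argsort(fitness)[-n_elite:]
        self.mean = self.pop[elite].mean(axis=0)

def external_fitness(x):
    # Stand-in for a non-jitable evaluation such as a simulator; higher is better.
    return -float(np.sum((x - 1.0) ** 2))

opt = ToyAskTell(pop_size=32, dim=2)
history = []
for generation in range(20):
    pop = opt.ask()                                          # 1. get candidates
    fitness = np.array([external_fitness(x) for x in pop])   # 2. score them outside
    opt.tell(fitness)                                        # 3. report scores back
    history.append(fitness.max())

print(opt.mean)  # should end up near the optimum at (1, 1)
```

Because scoring happens between `ask` and `tell`, the evaluation step can be any Python code, which is what makes this pattern usable for un-jitable problems.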

In [26]:
# run!
print("Start running...")
for generation in range(10):
    pop_nodes, pop_conns = jit_algorithm_ask(state)
    print(f"Generation {generation}: evaluating population...")
    fitness = evaluate_population(state, pop_nodes, pop_conns)

    state = jit_algorithm_tell(state, fitness)
    print(f"Generation {generation}: best fitness: {fitness.max()}")

    if fitness.max() >= 500:
        print("Fitness limit reached!")
        break

Start running...
Generation 0: evaluating population...


100%|██████████| 100/100 [00:04<00:00, 22.38it/s]


Generation 0: best fitness: 493.0
Generation 1: evaluating population...


100%|██████████| 100/100 [00:17<00:00,  5.62it/s]


Generation 1: best fitness: 500.0
Fitness limit reached!


The above code runs slowly for the following reasons:
1. We use a `for` loop to evaluate the fitness of each individual in the population sequentially, lacking parallel acceleration.
2. We do not take advantage of TensorNEAT’s GPU parallel execution capabilities.
3. There are too many switches between Python code and JAX code, causing unnecessary overhead.


The following code demonstrates an optimized `gymnasium` evaluation process:

In [27]:
# use NumPy because switching between NumPy and Python is cheaper than switching between JAX and Python
import numpy as np

jit_batch_transform = jax.jit(jax.vmap(genome.transform, in_axes=(None, 0, 0)))
jit_batch_forward = jax.jit(jax.vmap(genome.forward, in_axes=(None, 0, 0)))

POP_SIZE = 100
# Use multiple envs
envs = [gym.make("CartPole-v1") for _ in range(POP_SIZE)]
def accelerated_evaluate_population(state, pop_nodes, pop_conns):
    # transform the whole population with the batched transform
    pop_transformed = jit_batch_transform(state, pop_nodes, pop_conns)

    pop_observation = [env.reset()[0] for env in envs]
    pop_observation = np.asarray(pop_observation)
    pop_fitness = np.zeros(POP_SIZE)
    episode_over = np.zeros(POP_SIZE, dtype=bool)
    
    while not np.all(episode_over):
        # batch forward
        pop_action = jit_batch_forward(state, pop_transformed, pop_observation)
        pop_action = jax.device_get(pop_action)

        obs, reward, terminated, truncated = [], [], [], []
        # we still need to step the envs one by one
        for i in range(POP_SIZE):
            if episode_over[i]:
                # is already terminated
                # append zeros to keep the shape
                obs.append(np.zeros(4))
                reward.append(0.0)
                terminated.append(True)
                truncated.append(False)
                continue
            else:
                # step the env
                obs_, reward_, terminated_, truncated_, info_ = envs[i].step(pop_action[i])
                obs.append(obs_)
                reward.append(reward_)
                terminated.append(terminated_)
                truncated.append(truncated_)

        pop_observation = np.asarray(obs)
        pop_reward = np.asarray(reward)
        pop_terminated = np.asarray(terminated)
        pop_truncated = np.asarray(truncated)

        # update fitness and the episode-over mask
        pop_fitness += pop_reward * ~episode_over
        episode_over = episode_over | pop_terminated | pop_truncated

    return pop_fitness
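
The masking step in the loop above is easy to get subtly wrong, so here is a toy version of just that logic, with no environment involved. The episode lengths are invented for illustration. Note that the reward is added before `episode_over` is updated, so the reward from the final step of each episode still counts.

```python
import numpy as np

# Hypothetical episode lengths: individual i terminates on step lengths[i].
lengths = np.array([3, 1, 5, 2])
pop_size = lengths.size

pop_fitness = np.zeros(pop_size)
episode_over = np.zeros(pop_size, dtype=bool)
step = 0
while not np.all(episode_over):
    step += 1
    # Every slot reports a reward of 1.0, including slots whose episode has ended.
    pop_reward = np.ones(pop_size)
    pop_terminated = step >= lengths

    # Only individuals whose episode is still running accumulate reward.
    pop_fitness += pop_reward * ~episode_over
    episode_over = episode_over | pop_terminated

print(pop_fitness)  # [3. 1. 5. 2.]
```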

In [28]:
# Compare the speed between these two methods
# run once beforehand so that JAX compilation is not included in the timing
accelerated_evaluate_population(state, pop_nodes, pop_conns)

import time
time_tic = time.time()
fitness_slow = evaluate_population(state, pop_nodes, pop_conns)
slow_time = time.time() - time_tic

time_tic = time.time()
fitness_fast = accelerated_evaluate_population(state, pop_nodes, pop_conns)
fast_time = time.time() - time_tic

print(f"{slow_time=}, {fast_time=}")

100%|██████████| 100/100 [00:14<00:00,  6.87it/s]


slow_time=14.562041997909546, fast_time=1.134758710861206
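
For repeated measurements, a small helper can make the warm-up explicit. This is a sketch using only the standard library: `time_after_warmup` is a hypothetical name, and if the timed function returns JAX arrays rather than NumPy arrays, the result should also be blocked on (for example with `jax.block_until_ready`) before the clock is read. The last lines compute the speedup implied by the timings printed above.

```python
import time

def time_after_warmup(fn, *args, repeats=3):
    """Call `fn` once untimed (e.g. to trigger JIT compilation), then return the best wall time."""
    fn(*args)
    best = float("inf")
    for _ in range(repeats):
        tic = time.perf_counter()
        fn(*args)
        best = min(best, time.perf_counter() - tic)
    return best

# Speedup implied by the timings reported in the output above.
slow_time, fast_time = 14.562041997909546, 1.134758710861206
print(f"speedup: {slow_time / fast_time:.1f}x")  # speedup: 12.8x
```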
