# NEAT examples — XOR, parameter sweep, and CartPole

Welcome! This notebook shows how to use neat-python (NEAT) for:

- XOR: a tiny supervised learning example for evolving small feedforward networks.
- Parameter sweep: a quick grid search over NEAT hyperparameters (e.g., population size, species threshold).
- CartPole: a short control task using gymnasium's `CartPole-v1`.

What you can do here

- Run the XOR example to watch NEAT improve fitness over generations.
- Try a small sweep to see how hyperparameters affect performance and species counts.
- Evolve a controller for CartPole and inspect the resulting policy.

Quick edits (change these in the 'Constants' cells)

- XOR Constants: `XOR_POP_SIZE`, `N_XOR_GENERATIONS`, `XOR_ACTIVATION_DEFAULT`, `XOR_ACTIVATION_OPTIONS`.
- Parameter Sweep Constants: `SWEEP_POP_SIZES`, `SWEEP_SPECIES_THRESHOLDS`, `SWEEP_N_GENERATIONS`.
- CartPole Constants: `CART_POP_SIZE`, `CART_ACTIVATION_DEFAULT`, `CART_ACTIVATION_OPTIONS` (and change the `n` argument in the CartPole run cell for generations).

How to run

- Local (recommended): create and activate a virtual environment, install the dependencies, then run the cells top-to-bottom (or execute the notebook non-interactively with nbconvert):

    ```bash
    python -m venv .venv
    ./.venv/Scripts/Activate.ps1               # PowerShell (Windows)
    # source .venv/bin/activate                # bash/zsh (macOS/Linux)
    python -m pip install -r requirements.txt
    ```

- Google Colab: run the first code cell ("Google Colab setup") to install dependencies, then run cells top-to-bottom.  

  [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RichardJPovinelli/Evolutionary_Computation_Course/blob/main/NEAT_demo.ipynb)

Outputs and reproducibility

- Outputs (generated configs, the sweep summary, and checkpoints) are saved in `sweep_results/`.
- For reproducible quick tests, set a seed (e.g., `random.seed(0)`) near the constants, and keep the config files under version control where possible.
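
As a minimal sketch of the seeding tip: the sweep in this notebook reseeds Python's built-in `random` module before each replicate, and the snippet shows why that replays the same draws. The helper `draw_sequence` is hypothetical and only stands in for a run.

```python
import random

def draw_sequence(seed, count=5):
    """Hypothetical stand-in for a run: reseed, then take a few random draws."""
    random.seed(seed)
    return [random.random() for _ in range(count)]

first = draw_sequence(0)
second = draw_sequence(0)
other = draw_sequence(1)

print(first == second)  # True: same seed, same draws
print(first == other)   # False: different seed, different draws
```

Seeding `random` does not seed a Gymnasium environment; for CartPole, pass `seed=` to `env.reset` if you also want repeatable start states.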

## Setup environment

In [None]:
# Google Colab setup: install required packages when running in Colab
IN_COLAB = False
try:
    import google.colab
    IN_COLAB = True
except Exception:
    IN_COLAB = False

if IN_COLAB:
    print('Running in Google Colab — installing dependencies...')
    %pip install neat-python
    %pip install gymnasium[classic-control]
    # If you need additional system packages, install them here (apt-get)
    print('Dependencies installed; you may need to restart the runtime')
else:
    print('Not running in Colab — skipping Colab-specific setup')

In [None]:
# Imports
import os
import json
import random
import statistics
import itertools

import gymnasium as gym
import neat
from neat.nn import FeedForwardNetwork
from neat.reporting import BaseReporter

print('Imports OK')

In [None]:
# Create default NEAT config files in the working directory if they are missing
import os

xor_cfg = '''[NEAT]
fitness_criterion=max
fitness_threshold=3.9
pop_size=100
reset_on_extinction=False

[DefaultGenome]
activation_default=sigmoid
activation_mutate_rate=0.0
activation_options=sigmoid
aggregation_default=sum
aggregation_mutate_rate=0.0
aggregation_options=sum
bias_init_mean=0.0
bias_init_stdev=1.0
bias_max_value=5.0
bias_min_value=-5.0
bias_mutate_power=0.5
bias_mutate_rate=0.7
bias_replace_rate=0.1
compatibility_disjoint_coefficient=1.0
compatibility_weight_coefficient=0.5
conn_add_prob=0.5
conn_delete_prob=0.3
enabled_default=True
enabled_mutate_rate=0.01
feed_forward=True
initial_connection=partial_direct 0.5
node_add_prob=0.2
node_delete_prob=0.05
num_hidden=0
num_inputs=2
num_outputs=1
response_init_mean=1.0
response_init_stdev=0.0
response_max_value=1.0
response_min_value=1.0
response_mutate_rate=0.0
response_mutate_power=0.0
response_replace_rate=0.0
weight_init_mean=0.0
weight_init_stdev=1.0
weight_max_value=5.0
weight_min_value=-5.0
weight_mutate_power=0.5
weight_mutate_rate=0.8
weight_replace_rate=0.1

[DefaultSpeciesSet]
compatibility_threshold=3.0

[DefaultStagnation]
species_fitness_func=max
max_stagnation=15
species_elitism=1

[DefaultReproduction]
elitism=2
survival_threshold=0.2
'''

sweep_template = '''[NEAT]
fitness_criterion     = max
fitness_threshold     = 3.9
pop_size              = {pop_size}
reset_on_extinction   = False

[DefaultGenome]
activation_default      = sigmoid
activation_mutate_rate  = 0.0
activation_options      = sigmoid
aggregation_default     = sum
aggregation_mutate_rate = 0.0
aggregation_options     = sum
bias_init_mean          = 0.0
bias_init_stdev         = 1.0
bias_max_value          = 5.0
bias_min_value          = -5.0
bias_mutate_power       = 0.5
bias_mutate_rate        = 0.7
bias_replace_rate       = 0.1
compatibility_disjoint_coefficient = 1.0
compatibility_weight_coefficient   = 0.5
conn_add_prob           = 0.5
conn_delete_prob        = 0.3
enabled_default         = True
enabled_mutate_rate     = 0.01
feed_forward            = True
initial_connection      = partial_direct 0.5
node_add_prob           = 0.2
node_delete_prob        = 0.05
num_hidden              = 0
num_inputs              = 2
num_outputs             = 1
response_init_mean      = 1.0
response_init_stdev     = 0.0
response_max_value      = 1.0
response_min_value      = 1.0
response_mutate_rate    = 0.0
response_mutate_power   = 0.0
response_replace_rate   = 0.0
weight_init_mean        = 0.0
weight_init_stdev       = 1.0
weight_max_value        = 5.0
weight_min_value        = -5.0
weight_mutate_power     = 0.5
weight_mutate_rate      = 0.8
weight_replace_rate     = 0.1

[DefaultSpeciesSet]
compatibility_threshold = {compat_threshold}

[DefaultStagnation]
species_fitness_func = max
max_stagnation       = 12
species_elitism      = 1

[DefaultReproduction]
elitism            = 2
survival_threshold = 0.2
'''

cartpole_cfg = '''[NEAT]
fitness_criterion     = max
fitness_threshold     = 475.0
pop_size              = 150
reset_on_extinction   = False

[DefaultGenome]
activation_default      = tanh
activation_mutate_rate  = 0.0
activation_options      = tanh
aggregation_default     = sum
aggregation_mutate_rate = 0.0
aggregation_options     = sum
bias_init_mean          = 0.0
bias_init_stdev         = 1.0
bias_max_value          = 5.0
bias_min_value          = -5.0
bias_mutate_power       = 0.5
bias_mutate_rate        = 0.7
bias_replace_rate       = 0.1
compatibility_disjoint_coefficient = 1.0
compatibility_weight_coefficient   = 0.5
conn_add_prob           = 0.5
conn_delete_prob        = 0.3
enabled_default         = True
enabled_mutate_rate     = 0.01
feed_forward            = True
initial_connection      = full_direct
node_add_prob           = 0.2
node_delete_prob        = 0.05
num_hidden              = 0
num_inputs              = 4
num_outputs             = 1
response_init_mean      = 1.0
response_init_stdev     = 0.0
response_max_value      = 1.0
response_min_value      = 1.0
response_mutate_rate    = 0.0
response_mutate_power   = 0.0
response_replace_rate   = 0.0
weight_init_mean        = 0.0
weight_init_stdev       = 1.0
weight_max_value        = 5.0
weight_min_value        = -5.0
weight_mutate_power     = 0.5
weight_mutate_rate      = 0.8
weight_replace_rate     = 0.1

[DefaultSpeciesSet]
compatibility_threshold = 3.0

[DefaultStagnation]
species_fitness_func = max
max_stagnation       = 15
species_elitism      = 1

[DefaultReproduction]
elitism            = 2
survival_threshold = 0.2
'''

def write_if_missing(path, text):
    if not os.path.exists(path):
        with open(path, 'w', encoding='utf-8') as fh:
            fh.write(text)
        print(f'Wrote {path}')
    else:
        print(f'{path} already exists, skipping')

write_if_missing('xor_config.cfg', xor_cfg)
write_if_missing('sweep_config_template.cfg', sweep_template)
write_if_missing('cartpole_config.cfg', cartpole_cfg)
print('Config files checked/created (if missing).')

## XOR problem: inputs, labels, and evaluation

This section evolves a small feedforward network to solve the XOR truth table.

- Dataset: four 2D inputs with binary targets (0/1): (0,0→0), (0,1→1), (1,0→1), (1,1→0).
- Fitness: for each case we score `1.0 - (output - target)^2` and sum across the four cases (higher is better; maximum = 4.0).
- Network I/O: 2 inputs, 1 output. The raw output is printed; you can apply a threshold at 0.5 to interpret it as a class if desired.
- What to expect: with small populations and few generations, a perfect solution may not appear every run — watch the reporter output and see checkpoints in `sweep_results/`.
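
The fitness rule above can be worked through by hand. This sketch scores three hypothetical sets of network outputs (no NEAT involved) and applies the optional 0.5 threshold; `xor_fitness` and `to_class` are illustrative names, not part of the notebook.

```python
# The four XOR cases as (inputs, target) pairs
xor_cases = [((0.0, 0.0), 0.0), ((0.0, 1.0), 1.0), ((1.0, 0.0), 1.0), ((1.0, 1.0), 0.0)]

def xor_fitness(outputs):
    """Sum 1 - squared error over the four XOR cases (maximum 4.0)."""
    return sum(1.0 - (out - target) ** 2 for out, (_, target) in zip(outputs, xor_cases))

def to_class(output, threshold=0.5):
    """Interpret a raw output as a binary class."""
    return 1 if output > threshold else 0

perfect = [0.0, 1.0, 1.0, 0.0]
undecided = [0.5, 0.5, 0.5, 0.5]  # hypothetical network that always outputs 0.5
close = [0.1, 0.9, 0.8, 0.2]      # hypothetical near-solution

print(xor_fitness(perfect))          # 4.0
print(xor_fitness(undecided))        # 3.0
print(round(xor_fitness(close), 2))  # 3.9
print([to_class(o) for o in close])  # [0, 1, 1, 0]
```

A network that always outputs 0.5 already scores 3.0, so fitness values near 3.0 indicate no real progress; only the gap between 3.0 and 4.0 reflects learning.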

Tips:
- Reduce `N_XOR_GENERATIONS` for faster classroom demos; increase it (or `XOR_POP_SIZE`) for more reliable convergence.
- Try different activation functions via `XOR_ACTIVATION_DEFAULT` / `XOR_ACTIVATION_OPTIONS`.

### XOR Constants

In [None]:
XOR_ACTIVATION_DEFAULT = 'sigmoid'
XOR_ACTIVATION_OPTIONS = ['sigmoid'] # ['tanh', 'sigmoid', 'relu']
XOR_POP_SIZE = 100
N_XOR_GENERATIONS = 100

xor_input_pairs = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (1.0, 1.0)]
xor_expected_outputs = [0.0, 1.0, 1.0, 0.0]

## Main XOR Code

In [None]:
def compute_genome_xor_fitness(genome, config):
    network = FeedForwardNetwork.create(genome, config)
    fitness = 0.0
    for inputs, target in zip(xor_input_pairs, xor_expected_outputs):
        output = network.activate(inputs)[0]
        fitness += 1.0 - (output - target) ** 2
    return fitness

def evaluate_population_xor(genomes, config):
    for _, genome in genomes:
        genome.fitness = compute_genome_xor_fitness(genome, config)

# Run a single XOR experiment (short smoke test)
config_path = 'xor_config.cfg'
if not os.path.exists(config_path):
    raise FileNotFoundError(f"{config_path} not found. Add it to the code folder.")

config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction, neat.DefaultSpeciesSet, neat.DefaultStagnation, config_path)
config.pop_size = XOR_POP_SIZE
config.genome_config.activation_default = XOR_ACTIVATION_DEFAULT
config.genome_config.activation_options = XOR_ACTIVATION_OPTIONS

population = neat.Population(config)
population.add_reporter(neat.StdOutReporter(True))
statistics_reporter = neat.StatisticsReporter()
population.add_reporter(statistics_reporter)
os.makedirs('sweep_results', exist_ok=True)
population.add_reporter(neat.Checkpointer(5, filename_prefix='sweep_results/xor-'))
best_genome = population.run(evaluate_population_xor, n=N_XOR_GENERATIONS)
if best_genome is None:
    print('\nNo best_genome produced by population run')
else:
    print('\nBest genome:', best_genome)
    network = FeedForwardNetwork.create(best_genome, config)
    for inputs, target in zip(xor_input_pairs, xor_expected_outputs):
        network_output = network.activate(inputs)[0]
        print(f"Input={inputs} -> output={network_output:.3f}, target={target}")

## Parameter sweep
A compact grid search over two NEAT hyperparameters with small, repeatable runs and a JSON summary.

What it does
- Defines a grid over population size (`SWEEP_POP_SIZES`) and species compatibility threshold (`SWEEP_SPECIES_THRESHOLDS`).
- For each grid point, generates a config from `sweep_config_template.cfg` and writes it to `sweep_results/run_<id>_cfg.cfg`.
- Runs two replicates (random seeds 0 and 1) using the XOR evaluator (`evaluate_population_xor`) for `SWEEP_N_GENERATIONS` generations.
- For each replicate, records the best genome’s fitness and tracks the number of species per generation.
- Aggregates across replicates to compute:
  - `best_fitness_mean` and `best_fitness_stdev` (population stdev) of the best fitness
  - `avg_species_last_gen`: average number of species in the final generation
- Prints a progress line per grid point and writes a machine-readable summary to `sweep_results/sweep_summary.json`.
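
The config-generation step can be sketched in isolation. The template here is a hypothetical two-section excerpt (the real `sweep_config_template.cfg` has many more keys), and `render_config` is an illustrative name. Parsing the result with the standard-library `configparser` is just a quick way to check that the placeholders were filled.

```python
import configparser

# Hypothetical excerpt of a NEAT config template with two placeholders
template = """[NEAT]
pop_size = {pop_size}

[DefaultSpeciesSet]
compatibility_threshold = {compat_threshold}
"""

def render_config(pop_size, compat_threshold):
    """Fill the placeholders for one grid point."""
    return template.format(pop_size=pop_size, compat_threshold=compat_threshold)

text = render_config(pop_size=50, compat_threshold=2.5)

# Parse the rendered text back to confirm the values landed in the right sections
parser = configparser.ConfigParser()
parser.read_string(text)
print(parser.getint("NEAT", "pop_size"))                                # 50
print(parser.getfloat("DefaultSpeciesSet", "compatibility_threshold"))  # 2.5
```

Because the template goes through `str.format`, any literal `{` or `}` added to it would need to be doubled.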

Outputs
- One generated NEAT config per grid point under `sweep_results/`.
- Summary file: `sweep_results/sweep_summary.json`.
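
One way to use the summary file is to load it and rank grid points by mean best fitness. The records here are made-up values in the same shape the sweep writes, saved to a temporary directory so the sketch runs on its own.

```python
import json
import os
import tempfile

# Hypothetical records shaped like the sweep summary; the numbers are invented
example_summary = [
    {"pop_size": 50, "compat_threshold": 1.0, "best_fitness_mean": 3.10,
     "best_fitness_stdev": 0.20, "avg_species_last_gen": 6.0},
    {"pop_size": 100, "compat_threshold": 2.5, "best_fitness_mean": 3.55,
     "best_fitness_stdev": 0.05, "avg_species_last_gen": 3.5},
    {"pop_size": 200, "compat_threshold": 3.0, "best_fitness_mean": 3.80,
     "best_fitness_stdev": 0.10, "avg_species_last_gen": 2.0},
]

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "sweep_summary.json")
    with open(path, "w", encoding="utf-8") as fh:
        json.dump(example_summary, fh, indent=2)
    with open(path, encoding="utf-8") as fh:
        loaded = json.load(fh)

# Highest mean best fitness first
ranked = sorted(loaded, key=lambda row: row["best_fitness_mean"], reverse=True)
best = ranked[0]
print(best["pop_size"], best["compat_threshold"])  # 200 3.0
```

With only two replicates per grid point, small differences in the mean are weak evidence; treat the ranking as a hint rather than a conclusion.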

Customize it
- Change the grid via `SWEEP_POP_SIZES` and `SWEEP_SPECIES_THRESHOLDS`.
- Adjust runtime via `SWEEP_N_GENERATIONS` and the replicate seeds in the code (currently `[0, 1]`).
- Try different activation functions via `SWEEP_ACTIVATION_DEFAULT` / `SWEEP_ACTIVATION_OPTIONS` (used inside the helper).
- Switch tasks by replacing `evaluate_population_xor` with your own evaluator.
- To get checkpoints or richer logs during the sweep, add NEAT reporters (e.g., `StdOutReporter`, `Checkpointer`) inside `run_single_experiment`.

Notes
- Runtime scales roughly with |grid| × replicates × generations; start small for quick demos.
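
To make the scaling note concrete, this sketch counts runs, generations, and genome evaluations for the default grid (values copied from the sweep constants). The totals are upper bounds, assuming no run stops early by reaching the fitness threshold.

```python
import itertools

pop_sizes = [50, 100, 200]
thresholds = [1.0, 2.5, 3.0]
replicates = 2
generations = 100

grid = list(itertools.product(pop_sizes, thresholds))
runs = len(grid) * replicates
max_generations_total = runs * generations
# Each generation evaluates every genome in the population once
max_genome_evaluations = sum(pop for pop, _ in grid) * replicates * generations

print(runs)                    # 18
print(max_generations_total)   # 1800
print(max_genome_evaluations)  # 210000
```

Halving `generations` or dropping the largest population size cuts the budget far more than trimming a threshold value, which is a reasonable first move for classroom demos.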

In [None]:
SWEEP_ACTIVATION_OPTIONS = ['sigmoid', 'tanh', 'relu']
SWEEP_POP_SIZES = [50, 100, 200]
SWEEP_SPECIES_THRESHOLDS = [1.0, 2.5, 3.0]
SWEEP_N_GENERATIONS = 100
SWEEP_ACTIVATION_DEFAULT = 'sigmoid' 

In [None]:
# Helper to run one experiment (returns best fitness and species counts)
def run_single_experiment(config_path, max_generations=SWEEP_N_GENERATIONS):
    config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction, neat.DefaultSpeciesSet, neat.DefaultStagnation, config_path)
    # Use the sweep's own activation settings rather than the XOR ones
    config.genome_config.activation_default = SWEEP_ACTIVATION_DEFAULT
    config.genome_config.activation_options = SWEEP_ACTIVATION_OPTIONS
    population = neat.Population(config)
    species_counts = []

    class SpeciesSpy(BaseReporter):
        # Record the number of species after each generation is evaluated
        def post_evaluate(self, config, population, species, best_genome):
            species_counts.append(len(species.species))

    population.add_reporter(SpeciesSpy())
    best_genome = population.run(evaluate_population_xor, n=max_generations)
    if best_genome is None:
        return 0.0, species_counts
    return best_genome.fitness, species_counts


# Small sweep (fast settings for smoke test)
os.makedirs('sweep_results', exist_ok=True)
# compat_threshold is the distance threshold for speciation
param_grid = {'pop_size': SWEEP_POP_SIZES, 'compat_threshold': SWEEP_SPECIES_THRESHOLDS}
sweep_results = []
sweep_run_id = 0
# Read the template once; each grid point fills in its own values
with open('sweep_config_template.cfg', encoding='utf-8') as f:
    sweep_template_text = f.read()
for pop_size, compat_threshold in itertools.product(param_grid['pop_size'], param_grid['compat_threshold']):
    sweep_run_id += 1
    generated_cfg_text = sweep_template_text.format(pop_size=pop_size, compat_threshold=compat_threshold)
    config_path = f'sweep_results/run_{sweep_run_id}_cfg.cfg'
    with open(config_path, 'w', encoding='utf-8') as f:
        f.write(generated_cfg_text)
    replicate_runs = []
    for random_seed in [0, 1]:
        random.seed(random_seed)
        fitness, species_counts = run_single_experiment(config_path, max_generations=SWEEP_N_GENERATIONS)
        replicate_runs.append({'fitness': fitness, 'species_counts': species_counts})
    best_fitnesses = [r['fitness'] for r in replicate_runs]
    avg_species_last_gen = statistics.mean([r['species_counts'][-1] for r in replicate_runs]) if all(r['species_counts'] for r in replicate_runs) else 0
    sweep_results.append({
        'pop_size': pop_size,
        'compat_threshold': compat_threshold,
        'best_fitness_mean': statistics.mean(best_fitnesses),
        'best_fitness_stdev': statistics.pstdev(best_fitnesses),
        'avg_species_last_gen': avg_species_last_gen,
    })
    print(f"[population={pop_size}, compat={compat_threshold}] best_mean={statistics.mean(best_fitnesses):.3f} ± {statistics.pstdev(best_fitnesses):.3f}; species@last={avg_species_last_gen:.1f}")
with open('sweep_results/sweep_summary.json', 'w', encoding='utf-8') as f:
    json.dump(sweep_results, f, indent=2)
print('\nSaved sweep_results to sweep_results/sweep_summary.json')

## CartPole evaluation and run

The CartPole section evaluates genomes on the `CartPole-v1` environment using gymnasium.

What it does
- Builds a NEAT config from `cartpole_config.cfg` and applies quick overrides from the CartPole constants (`CART_POP_SIZE`, `CART_ACTIVATION_*`).
- Uses a simple discrete policy: the network outputs a scalar; we take action 1 if output > 0.0, else 0.
- Evaluates each genome by averaging episode returns over multiple episodes in `compute_genome_cartpole_average_return` (higher is better).
- Evolves a population with `population.run(evaluate_population_cartpole, n=...)`, where `n` is the number of generations, and reports the winner’s average fitness.
- Adds reporters for console logs and periodic checkpoints under `sweep_results/cartpole-*`.

Outputs
- Checkpoints every 5 generations in `sweep_results/cartpole-*` (for resuming or inspection).
- Console logs of generation progress via `StdOutReporter`.

Customize it
- Population and activations: edit `CART_POP_SIZE`, `CART_ACTIVATION_DEFAULT`, and `CART_ACTIVATION_OPTIONS`.
- Generations: change the `n` argument in the run call (default 20).
- Evaluation budget: adjust `episodes` and `max_steps` in `compute_genome_cartpole_average_return` for speed vs. stability.
- Policy: replace `choose_cartpole_action` with a different mapping from network output to discrete actions.
- Environment: try other Gymnasium tasks by changing `gym.make('CartPole-v1')` and matching input/output sizes in the config.
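
When changing the policy or the activation function, the decision threshold has to match the output range. This sketch uses a hypothetical `threshold_action` helper and a plain logistic sigmoid (not necessarily the library's exact scaling) to show how a `> 0.0` threshold behaves with each activation.

```python
import math

def threshold_action(network_output, threshold):
    """Map a scalar network output to a discrete action (0 or 1)."""
    return 1 if network_output > threshold else 0

def sigmoid(z):
    """Standard logistic function; its output lies strictly between 0 and 1."""
    return 1.0 / (1.0 + math.exp(-z))

pre_activations = [-3.0, -0.5, 0.5, 3.0]

tanh_actions = [threshold_action(math.tanh(z), 0.0) for z in pre_activations]
sigmoid_zero = [threshold_action(sigmoid(z), 0.0) for z in pre_activations]
sigmoid_half = [threshold_action(sigmoid(z), 0.5) for z in pre_activations]

print(tanh_actions)  # [0, 0, 1, 1]
print(sigmoid_zero)  # [1, 1, 1, 1]
print(sigmoid_half)  # [0, 0, 1, 1]
```

With a sigmoid output, a threshold of 0.0 always selects action 1, so either use a zero-centered activation such as tanh or move the threshold to 0.5.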

Notes
- CartPole's initial state is randomized on every reset, so returns vary between episodes; averaging across multiple episodes yields more stable fitness estimates.
- If you see unstable training, try larger populations or tune the species compatibility threshold in the NEAT config.
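
The effect of averaging can be illustrated without running the environment. This sketch simulates a hypothetical policy whose episode return is a fixed value plus Gaussian noise (an assumption for illustration, not a model of CartPole) and compares the spread of single-episode and five-episode fitness estimates.

```python
import random
import statistics

rng = random.Random(0)  # private generator so the sketch does not disturb global state

def noisy_episode_return():
    """Hypothetical episode return: fixed policy quality plus start-state noise."""
    return 200.0 + rng.gauss(0.0, 40.0)

def average_return(episodes):
    """Fitness estimate: mean return over the given number of episodes."""
    return statistics.mean(noisy_episode_return() for _ in range(episodes))

single = [average_return(1) for _ in range(2000)]
averaged = [average_return(5) for _ in range(2000)]

spread_single = statistics.pstdev(single)
spread_averaged = statistics.pstdev(averaged)
print(spread_averaged < spread_single)  # True
```

More episodes per genome give steadier selection pressure at a proportional cost in runtime, which is the trade-off behind the `episodes` argument.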

In [None]:
CART_POP_SIZE = 50
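# tanh is symmetric about 0, matching the `> 0.0` threshold in choose_cartpole_action;
# a sigmoid output is always positive, so that threshold would always pick action 1.
CART_ACTIVATION_DEFAULT = 'tanh'
CART_ACTIVATION_OPTIONS = ['tanh'] # ['tanh', 'sigmoid', 'relu']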

## CartPole Main Code

In [None]:
def choose_cartpole_action(network, observation):
    network_output = network.activate(observation)[0]
    return 1 if network_output > 0.0 else 0

def compute_genome_cartpole_average_return(genome, config, episodes=2, max_steps=500):
    env = gym.make('CartPole-v1')
    network = FeedForwardNetwork.create(genome, config)
    total_return = 0.0
    for _ in range(episodes):
        observation, _ = env.reset()
        episode_return = 0.0
        for _ in range(max_steps):
            action = choose_cartpole_action(network, observation)
            observation, r, terminated, truncated, _ = env.step(action)
            episode_return += float(r)
            if terminated or truncated:
                break
        total_return += episode_return
    env.close()
    return total_return / episodes

def evaluate_population_cartpole(genomes, config):
    for _, genome in genomes:
        genome.fitness = compute_genome_cartpole_average_return(genome, config)


# Run CartPole (short test)
config_path = 'cartpole_config.cfg'
if not os.path.exists(config_path):
    raise FileNotFoundError(f"{config_path} not found. Add it to the code folder.")

config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction, neat.DefaultSpeciesSet, neat.DefaultStagnation, config_path)
config.pop_size = CART_POP_SIZE
config.genome_config.activation_default = CART_ACTIVATION_DEFAULT
config.genome_config.activation_options = CART_ACTIVATION_OPTIONS
population = neat.Population(config)

population.add_reporter(neat.StdOutReporter(True))
statistics_reporter = neat.StatisticsReporter()
population.add_reporter(statistics_reporter)
os.makedirs('sweep_results', exist_ok=True)
population.add_reporter(neat.Checkpointer(5, filename_prefix='sweep_results/cartpole-'))
best_genome = population.run(evaluate_population_cartpole, n=20)  # n is the number of generations, not the population size
if best_genome is None:
    print('\nNo best_genome produced by CartPole population run')
else:
    print('\nWinner avg fitness:', compute_genome_cartpole_average_return(best_genome, config, episodes=3))