# Problem-Specific Coordinate Generation for HyperNEAT Substrates

## Setup

Testing was done on a GPU with 16 GB of VRAM and CUDA 12.8. VRAM usage is determined mainly by the substrate size and the population size.
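
To get a feel for how substrate and population size drive memory, the following is a back-of-the-envelope sketch, not a measurement. It assumes a fully connected, layered substrate with one float32 weight per connection and ignores activations, environment state, and compiler buffers. The function name, the hidden layer widths, and the population size are illustrative assumptions; only the 27 inputs and 8 outputs come from the environment used in this notebook.

```python
def estimate_weight_memory_mb(layer_sizes, pop_size, bytes_per_weight=4):
    """Rough lower bound on memory for dense substrate weights across a population."""
    # Connections between each pair of consecutive layers in a fully connected substrate.
    connections = sum(a * b for a, b in zip(layer_sizes, layer_sizes[1:]))
    return connections * pop_size * bytes_per_weight / 1024**2


# 27 inputs and 8 outputs match the environment in this notebook; the hidden
# layer widths and the population size are illustrative assumptions.
small = estimate_weight_memory_mb([27, 32, 8], pop_size=1000)
large = estimate_weight_memory_mb([27, 256, 256, 8], pop_size=1000)
print(f"small substrate: {small:.1f} MB, large substrate: {large:.1f} MB")
```

Because the connection count grows with the product of adjacent layer widths, widening the substrate tends to cost far more memory than adding a few individuals to the population.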

### Imports

This setup requires several dependencies, mainly TensorNEAT, JAX, NumPy, Matplotlib, NetworkX, scikit-learn, and wandb (for logging). Using a virtual environment (e.g. conda) is highly recommended. Python 3.10.18 was used for development and testing.

In [1]:
import jax
import numpy as np
from collections import defaultdict
import pickle

from config import config
from substrate_generation.data_sampling import collect_random_policy_data, collect_trained_agent_policy_data
from evol_pipeline.brax_env import CustomBraxEnv
from utils.utils import setup_folder_structure

First, set up the output folder structure so that later cells do not fail on missing directories.

In [2]:
OUTPUT_DIR = config["experiment"]["output_dir"]
setup_folder_structure(OUTPUT_DIR)
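
The implementation of `setup_folder_structure` is not shown in this notebook. As a minimal sketch of what such a helper might do, the version below creates an output directory with a few subfolders using `pathlib`. The function name `setup_folder_structure_sketch` and the subfolder names are hypothetical and chosen only for illustration.

```python
from pathlib import Path
import tempfile


def setup_folder_structure_sketch(output_dir, subfolders=("data", "plots", "models")):
    """Create output_dir and its subfolders; safe to call repeatedly."""
    root = Path(output_dir)
    for name in subfolders:
        # parents=True creates output_dir itself; exist_ok=True makes reruns a no-op.
        (root / name).mkdir(parents=True, exist_ok=True)
    return root


# Demonstrate on a throwaway directory rather than the real output folder.
with tempfile.TemporaryDirectory() as tmp:
    root = setup_folder_structure_sketch(Path(tmp) / "output")
    created = sorted(p.name for p in root.iterdir())
    print(created)
```

Creating every directory up front means that a write late in a long run cannot fail just because a folder is missing.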

### Setup Environment

[Brax environments](https://github.com/google/brax/tree/main/brax/envs) are used for this experiment through the [TensorNEAT wrapper](https://github.com/EMI-Group/tensorneat/tree/main/src/tensorneat/problem/rl).

In [3]:
env_name = config["experiment"]["env_name"]
env_problem = CustomBraxEnv(
    env_name=env_name,
    backend=config["environment"]["backend"],
    brax_args=config["environment"]["brax_args"],
    max_step=config["environment"]["max_step"],
    repeat_times=config["environment"]["repeat_times"],
    obs_normalization=False,
    sample_episodes=16,
)
obs_size = env_problem.input_shape[0]
act_size = env_problem.output_shape[0]

print("env_problem.input_shape: ", env_problem.input_shape)
print("env_problem.output_shape: ", env_problem.output_shape)

2025-09-07 21:17:09.659398: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.
2025-09-07 21:17:20.697392: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.


env_problem.input_shape:  (27,)
env_problem.output_shape:  (8,)


## Data Sampling

Before we can apply data-driven approaches to substrate generation, we need data to work with. It can be produced either by taking random actions or by training an agent for a short period and then sampling from its actions. Both sampling methods are used here so that they can be compared.

### Data Sampling Method 1: Trained Agent

First, we train a number of agents for a few generations on a simple substrate, solely to sample their actions.
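
The pattern used here is to give each agent its own random stream, collect one array per run, and stack the runs row-wise. The sketch below illustrates that pattern in plain NumPy under stated assumptions: `fake_agent_run` is a hypothetical stand-in for the real train-and-sample call, `SeedSequence.spawn` is swapped in for JAX key splitting, and the step count and row width (27 observation values plus 8 action values) are assumed for illustration.

```python
import numpy as np

NUM_AGENTS = 2          # assumption: two independently trained agents
STEPS_PER_AGENT = 1000  # assumption: sampled steps per agent
ROW_WIDTH = 35          # assumption: 27 observation + 8 action values per row


def fake_agent_run(rng, num_steps, row_width):
    """Hypothetical stand-in for one train-and-sample run; one row per sampled step."""
    return rng.normal(size=(num_steps, row_width))


# One independent random stream per agent, analogous to splitting a JAX key.
child_seeds = np.random.SeedSequence(42).spawn(NUM_AGENTS)
runs = [
    fake_agent_run(np.random.default_rng(seed), STEPS_PER_AGENT, ROW_WIDTH)
    for seed in child_seeds
]

# vstack concatenates along the step axis, so the column layout is preserved.
combined = np.vstack(runs)
print(combined.shape)
```

Separate streams keep the runs from being identical copies of each other, which would add rows without adding information.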


In [4]:
key = jax.random.PRNGKey(config["experiment"]["seed"]) # Use seed from config
num_trained_agent_sampling = config["data_sampling"]["num_agents_to_sample"]
all_trained_agent_data = []
for i in range(num_trained_agent_sampling):
    print(f"\nTraining and sampling agent #{i+1}/{num_trained_agent_sampling}")
    key, subkey = jax.random.split(key)
    trained_agent_data_run = collect_trained_agent_policy_data(
        env_problem=env_problem,
        key=subkey,
        num_steps=config["data_sampling"]["sampling_steps"],
        training_config=config["data_sampling"]["trained_agent_sampling"]
    )
    all_trained_agent_data.append(trained_agent_data_run)

combined_trained_agent_data = np.vstack(all_trained_agent_data)

print(f"Data collection complete. Combined data from {num_trained_agent_sampling} trained agents.")
print(f"Shape of each run's data: {all_trained_agent_data[0].shape}")
print(f"Shape of combined trained agents data: {combined_trained_agent_data.shape}")



Training and sampling agent #1/2

Starting Agent Training and Data Collection
Configuring and training the agent...
Query dimension for sampling:  4
initializing
initializing finished
start compile
compile finished, cost time: 32.839402s
Generation: 1, Cost time: 6732.27ms
 	fitness: valid cnt: 1000, max: 103.9364, min: -313.2156, mean: -209.9459, std: 94.1488

	node counts: max: 7, min: 5, mean: 6.06
 	conn counts: max: 7, min: 0, mean: 4.62
 	species: 20, [454, 2, 62, 1, 2, 4, 2, 298, 1, 5, 1, 5, 1, 4, 1, 3, 1, 1, 6, 146]

Generation: 2, Cost time: 6652.57ms
 	fitness: valid cnt: 1000, max: 102.6933, min: -314.9031, mean: -111.4713, std: 147.4896

	node counts: max: 8, min: 5, mean: 6.20
 	conn counts: max: 8, min: 0, mean: 4.16
 	species: 20, [120, 144, 312, 51, 9, 24, 30, 3, 25, 21, 1, 3, 15, 12, 11, 7, 6, 3, 1, 202]

Generation: 3, Cost time: 6680.04ms
 	fitness: valid cnt: 1000, max: 104.4815, min: -315.2883, mean: -20.1863, std: 148.7720

	node counts: max: 9, min: 5, mean: 6.3

2025-09-07 21:22:03.735156: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.
2025-09-07 21:22:03.735172: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.


Causal expert data collection finished.


Training and sampling agent #2/2

Starting Agent Training and Data Collection
Configuring and training the agent...
Query dimension for sampling:  4
initializing
initializing finished
start compile
compile finished, cost time: 33.997189s
Generation: 1, Cost time: 6965.30ms
 	fitness: valid cnt: 1000, max: 101.6768, min: -314.9246, mean: -206.7634, std: 96.8780

	node counts: max: 7, min: 5, mean: 6.05
 	conn counts: max: 7, min: 0, mean: 4.60
 	species: 20, [517, 4, 88, 19, 73, 38, 1, 106, 2, 1, 2, 1, 7, 1, 2, 19, 1, 2, 3, 113]

Generation: 2, Cost time: 6978.18ms
 	fitness: valid cnt: 1000, max: 102.6004, min: -325.9275, mean: -107.7709, std: 147.4068

	node counts: max: 8, min: 5, mean: 6.15
 	conn counts: max: 8, min: 0, mean: 4.30
 	species: 20, [100, 43, 75, 1, 110, 84, 188, 41, 26, 21, 21, 79, 13, 13, 4, 8, 8, 3, 3, 159]

Generation: 3, Cost time: 6921.65ms
 	fitness: valid cnt: 1000, max: 102.6053, min: -309.0288, mean: 5.7238, std: 137.

### Data Sampling Method 2: Random Policy

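Next, we sample data from the environment using a random action policy. The number of steps equals the total number of steps sampled from the trained agents, so both data sets have the same size.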
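
As a rough illustration of random-policy sampling, the sketch below rolls out a toy environment in NumPy. It is not the notebook's `collect_random_policy_data`: the dynamics in `toy_step`, the uniform action range of [-1, 1), and the row layout of observation followed by action are all assumptions made for the example.

```python
import numpy as np

OBS_SIZE, ACT_SIZE = 27, 8  # shapes reported by the environment in this notebook


def toy_step(obs, action, weights):
    """Hypothetical stand-in dynamics; the real notebook steps a Brax environment."""
    return np.tanh(0.9 * obs + weights @ action)


def collect_random_policy_rows(num_steps, seed=0):
    rng = np.random.default_rng(seed)
    weights = rng.normal(scale=0.1, size=(OBS_SIZE, ACT_SIZE))
    obs = np.zeros(OBS_SIZE)
    rows = []
    for _ in range(num_steps):
        # Random policy: every action dimension drawn uniformly from [-1, 1).
        action = rng.uniform(-1.0, 1.0, size=ACT_SIZE)
        # Assumed row layout: observation first, then the action taken in it.
        rows.append(np.concatenate([obs, action]))
        obs = toy_step(obs, action, weights)
    return np.asarray(rows)


random_rows = collect_random_policy_rows(num_steps=2000)
print(random_rows.shape)
```

A random policy needs no training, but it tends to visit only the states reachable by uncoordinated actions, which is one reason to compare it against data from trained agents.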

In [5]:
key, random_key = jax.random.split(key)
num_random_sampling_steps = num_trained_agent_sampling * config["data_sampling"]["sampling_steps"]
random_data = collect_random_policy_data(env_problem, random_key, num_random_sampling_steps)

Starting data collection for 2000 steps using a random policy...
Causal data collection finished.


In [6]:
data_sources = {
    "trained": combined_trained_agent_data,
    "random": random_data
}

analysis_io_coors = defaultdict(lambda: defaultdict(dict))

In [8]:
with open(f"{OUTPUT_DIR}/data_sources.pkl", "wb") as f:
    pickle.dump(data_sources, f)
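
A later session can reload the saved dictionary without repeating the sampling. The sketch below shows the round trip on small placeholder arrays in a temporary directory; `demo_sources` and its array shapes are stand-ins for illustration, not the sampled data.

```python
import pickle
import tempfile
from pathlib import Path

import numpy as np

# Placeholder arrays standing in for the sampled data.
demo_sources = {"trained": np.zeros((4, 3)), "random": np.ones((4, 3))}

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "data_sources.pkl"
    with open(path, "wb") as f:
        pickle.dump(demo_sources, f)
    # Reload in binary mode, as a later session would.
    with open(path, "rb") as f:
        restored = pickle.load(f)

for name, array in restored.items():
    print(name, array.shape)
```

Pickle files should only be loaded from trusted sources, since unpickling can execute arbitrary code.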
