# Problem-Specific Coordinate Generation for HyperNEAT Substrates

# Part 1: Data 

In this first part, we sample the data needed by the data-dependent substrate generation methods.

## Setup

Testing was done on a GPU with 16 GB of VRAM and CUDA 12.8. VRAM usage depends mainly on the substrate size and the population size.
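By default, JAX preallocates a large share of GPU memory (75%) when it initializes the GPU backend, so the usage shown by tools such as `nvidia-smi` does not reflect what the substrate and population actually need. If memory becomes a problem, JAX's documented environment variables can change this. The following is a general JAX sketch, not a setting taken from this project's configuration:

```python
import os

# Set these before JAX initializes its GPU backend; the safest place is
# before the first `import jax` in the process (or exported in the shell).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate GPU memory on demand
# Alternative: keep preallocation but cap it to a fraction of total GPU memory.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

import jax

# Lists the devices JAX will use; a CUDA device should appear on a GPU setup.
print(jax.devices())
```

On-demand allocation can fragment memory and is typically a little slower, so it is mainly useful for checking how much VRAM a given substrate and population size really use.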

### Imports

The main dependencies are TensorNEAT, JAX, NumPy, Matplotlib, NetworkX, scikit-learn, and wandb (for logging). Using a virtual environment (e.g. conda) is highly recommended. Python 3.10.18 was used for development and testing.
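One quick way to confirm that these packages are available in the active environment is `importlib.util.find_spec`, which looks a module up without importing it. Note that scikit-learn is imported as `sklearn`. The helper name `check_dependencies` is illustrative and not part of this project:

```python
import importlib.util
import sys

# Import names of the main dependencies (scikit-learn is imported as "sklearn").
REQUIRED_MODULES = ["tensorneat", "jax", "numpy", "matplotlib", "networkx", "sklearn", "wandb"]


def check_dependencies(names):
    """Map each top-level module name to True if it can be found, else False."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


print("Python", ".".join(str(part) for part in sys.version_info[:3]))
for name, found in check_dependencies(REQUIRED_MODULES).items():
    print(f"{name:<12} {'ok' if found else 'MISSING'}")
```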

In [1]:
import jax
import numpy as np

from config import config

from substrate_generation.data_sampling import collect_random_policy_data, collect_trained_agent_policy_data
from evol_pipeline.brax_env import CustomBraxEnv
from utils.utils import save_data_sources

### Setup Environment

[Brax environments](https://github.com/google/brax/tree/main/brax/envs) are used for this experiment through the [TensorNEAT wrapper](https://github.com/EMI-Group/tensorneat/tree/main/src/tensorneat/problem/rl).

In [2]:
env_name = config["experiment"]["env_name"]
env_problem = CustomBraxEnv(
    env_name=env_name,
    backend=config["environment"]["backend"],
    brax_args=config["environment"]["brax_args"],
    max_step=config["environment"]["max_step"],
    repeat_times=config["environment"]["repeat_times"],
    obs_normalization=config["environment"]["obs_normalization"],
    sample_episodes=config["environment"]["sample_episodes"],
)
obs_size = env_problem.input_shape[0]
act_size = env_problem.output_shape[0]
feature_dims_repeats = config["data_analysis"]["feature_dims"]

print("env_problem.input_shape: ", env_problem.input_shape)
print("env_problem.output_shape: ", env_problem.output_shape)

env_problem.input_shape:  (27,)
env_problem.output_shape:  (8,)
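All settings come from a `config` dictionary imported from a `config` module whose contents are not shown in this notebook. Below is a hypothetical layout containing only the keys accessed in this part. Two values can be inferred from the logs (2 agents, and 2 agents times 1000 steps gives the 2000 random steps), and the output path matches the save message at the end; every other value is a placeholder assumption, not the setting used in the experiments:

```python
# Hypothetical config.py: only the keys read in this notebook are shown.
config = {
    "experiment": {
        "env_name": "ant",                        # assumed; consistent with the (27,)/(8,) shapes above
        "seed": 0,                                # assumed placeholder
        "data_sources_path": "data_sources.npz",  # matches the save message
    },
    "environment": {
        "backend": "generalized",    # assumed Brax physics backend
        "brax_args": {},             # assumed: extra keyword arguments for the Brax env
        "max_step": 1000,            # assumed placeholder
        "repeat_times": 1,           # assumed placeholder
        "obs_normalization": False,  # assumed placeholder
        "sample_episodes": 1,        # assumed placeholder
    },
    "data_analysis": {
        "feature_dims": None,        # placeholder; its format is not shown in this part
    },
    "data_sampling": {
        "num_agents_to_sample": 2,   # inferred from "agent #1/2" in the log
        "sampling_steps": 1000,      # inferred: 2 agents x 1000 = 2000 random steps
    },
}
```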


## Data Sampling

Before data-driven approaches can be applied to substrate generation, some data needs to be collected. This can be done either by taking random actions or by briefly training an agent and then sampling its actions. Both sampling methods are used here for comparison.
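Both collection calls in this notebook pass `obs_diff_only=True` and `do_normalization=True`. What these flags do inside `substrate_generation.data_sampling` is not shown here; one plausible reading, used purely as an assumption, is that each sample pairs the change in observation with the action that caused it, and that every column is then standardized. A NumPy sketch of that hypothetical post-processing (the function name `to_diff_samples` is illustrative):

```python
import numpy as np


def to_diff_samples(obs, next_obs, actions, normalize=True, eps=1e-8):
    """Hypothetical post-processing of one rollout.

    obs, next_obs: arrays of shape (T, obs_dim); actions: shape (T, act_dim).
    Returns an array of shape (T, obs_dim + act_dim) holding
    [next_obs - obs, action] per step, optionally z-scored per column.
    """
    obs_diff = next_obs - obs  # change in observation caused by each step
    samples = np.hstack([obs_diff, actions])
    if normalize:
        # Standardize each column so observation and action dimensions share a scale;
        # eps avoids division by zero for constant columns.
        samples = (samples - samples.mean(axis=0)) / (samples.std(axis=0) + eps)
    return samples


rng = np.random.default_rng(0)
obs = rng.normal(size=(50, 27))
next_obs = obs + 0.1 * rng.normal(size=(50, 27))
actions = rng.uniform(-1.0, 1.0, size=(50, 8))
print(to_diff_samples(obs, next_obs, actions).shape)  # (50, 35)
```

Using observation differences rather than raw observations focuses the data on what each action changes, which is the kind of relationship a data-dependent substrate layout would try to exploit.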

### Data Sampling Method 1: Trained Agent

First, we train a number of agents for a few generations on a simple substrate, solely so that their actions can be sampled.


In [3]:
key = jax.random.PRNGKey(config["experiment"]["seed"])
num_trained_agent_sampling = config["data_sampling"]["num_agents_to_sample"]
all_trained_agent_data = []
for i in range(num_trained_agent_sampling):
    print(f"\nTraining and sampling agent #{i+1}/{num_trained_agent_sampling}")
    key, subkey = jax.random.split(key)
    trained_agent_data_run = collect_trained_agent_policy_data(
        env_problem=env_problem,
        key=subkey,
        num_steps=config["data_sampling"]["sampling_steps"],
        obs_diff_only=True,
        do_normalization=True,
    )
    all_trained_agent_data.append(trained_agent_data_run)

combined_trained_agent_data = np.vstack(all_trained_agent_data)

print(f"Data collection complete. Combined data from {num_trained_agent_sampling} trained agents.")
print(f"Shape of each run's data: {all_trained_agent_data[0].shape}")
print(f"Shape of combined trained agents data: {combined_trained_agent_data.shape}")



Training and sampling agent #1/2

Starting Agent Training and Data Collection
Configuring and training the agent...
Query dimension for sampling:  4
initializing
initializing finished
start compile
compile finished, cost time: 32.275134s
Generation: 1, Cost time: 7079.28ms
 	fitness: valid cnt: 1000, max: 50.6148, min: -3955.6804, mean: -1613.9531, std: 1381.5681

	node counts: max: 8, min: 6, mean: 7.27
 	conn counts: max: 12, min: 4, mean: 10.35
 	species: 10, [122, 59, 59, 39, 14, 1, 1, 1, 1, 703]

Generation: 2, Cost time: 7042.62ms
 	fitness: valid cnt: 1000, max: 51.5624, min: -3956.4119, mean: -541.1389, std: 1094.5309

	node counts: max: 9, min: 6, mean: 7.45
 	conn counts: max: 14, min: 3, mean: 10.52
 	species: 10, [113, 74, 84, 23, 5, 52, 34, 26, 17, 572]

Generation: 3, Cost time: 7054.15ms
 	fitness: valid cnt: 1000, max: 52.8876, min: -3952.0601, mean: -146.7727, std: 662.6440

	node counts: max: 10, min: 5, mean: 7.66
 	conn counts: max: 14, min: 2, mean: 10.55
 	specie

Causal expert data collection finished.


Training and sampling agent #2/2

Starting Agent Training and Data Collection
Configuring and training the agent...
Query dimension for sampling:  4
initializing
initializing finished
start compile
compile finished, cost time: 32.474618s
Generation: 1, Cost time: 7108.83ms
 	fitness: valid cnt: 1000, max: 50.9357, min: -3970.1001, mean: -1626.3590, std: 1422.9139

	node counts: max: 8, min: 6, mean: 7.25
 	conn counts: max: 12, min: 5, mean: 10.27
 	species: 10, [94, 96, 77, 1, 6, 1, 21, 2, 1, 701]

Generation: 2, Cost time: 6991.08ms
 	fitness: valid cnt: 1000, max: 53.4469, min: -3956.0967, mean: -584.2594, std: 1145.4045

	node counts: max: 9, min: 6, mean: 7.40
 	conn counts: max: 14, min: 5, mean: 10.45
 	species: 10, [19, 90, 119, 101, 39, 18, 35, 11, 17, 551]

Generation: 3, Cost time: 6966.30ms
 	fitness: valid cnt: 1000, max: 55.7937, min: -3955.3765, mean: -339.9572, std: 918.6163

	node counts: max: 10, min: 6, mean: 7.52
 	conn coun
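`np.vstack` concatenates the runs along the first axis (which requires every run to have the same number of columns), so after stacking it is no longer visible which rows came from which agent. If a later analysis needs that information, a small helper can build a parallel array of run indices. `stack_with_run_ids` is a hypothetical addition, not part of `utils`:

```python
import numpy as np


def stack_with_run_ids(runs):
    """Stack per-run 2-D arrays like np.vstack and return the run index of every row."""
    data = np.vstack(runs)
    run_ids = np.concatenate([np.full(len(run), i) for i, run in enumerate(runs)])
    return data, run_ids


# Usage with the list built in the cell above:
# combined_trained_agent_data, trained_run_ids = stack_with_run_ids(all_trained_agent_data)
```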

### Data Sampling Method 2: Random Policy

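Now we sample data from the environment using a random action policy. To keep the two data sources comparable, the number of random steps equals the total number of steps sampled across all trained agents.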
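A random policy needs no training: at every step an action is drawn independently of the observation. The following is a self-contained JAX sketch of such a rollout, assuming actions lie in [-1, 1] and using a toy linear system as a stand-in for the Brax environment. `collect_random_rollout` and `toy_step` are hypothetical and are not the project's `collect_random_policy_data`:

```python
import jax
import jax.numpy as jnp


def collect_random_rollout(key, step_fn, init_obs, act_size, num_steps):
    """Roll out a uniform random policy over [-1, 1]^act_size with jax.lax.scan.

    step_fn(obs, action) -> next_obs is a stand-in for an environment step.
    Returns (obs_diffs, actions) with shapes (num_steps, obs_dim) and (num_steps, act_size).
    """
    keys = jax.random.split(key, num_steps)  # one independent key per step

    def body(obs, k):
        action = jax.random.uniform(k, (act_size,), minval=-1.0, maxval=1.0)
        next_obs = step_fn(obs, action)
        return next_obs, (next_obs - obs, action)  # carry, per-step output

    _, (obs_diffs, actions) = jax.lax.scan(body, init_obs, keys)
    return obs_diffs, actions


# Toy linear dynamics standing in for a Brax environment (assumption).
A = 0.9 * jnp.eye(3)
B = 0.1 * jnp.ones((3, 2))
toy_step = lambda obs, act: A @ obs + B @ act

obs_diffs, actions = collect_random_rollout(
    jax.random.PRNGKey(0), toy_step, jnp.zeros(3), act_size=2, num_steps=100
)
```

Using `jax.lax.scan` keeps the whole rollout inside a single traced loop, which is the usual way to express sequential environment steps in JAX.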

In [4]:
key, random_key = jax.random.split(key)
num_random_sampling_steps = num_trained_agent_sampling * config["data_sampling"]["sampling_steps"]
random_data = collect_random_policy_data(
    env_problem, random_key, 
    num_random_sampling_steps, 
    obs_diff_only=True,
    do_normalization=True,
)

Starting data collection for 2000 steps using a random policy...
Causal data collection finished.


In [5]:
data_sources = {
    "trained": combined_trained_agent_data,
    "random": random_data
}

save_data_sources(data_sources, config["experiment"]["data_sources_path"])

Successfully saved data sources to: data_sources.npz
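In later parts, the archive can be read back with `np.load`, which returns an `NpzFile` whose `.files` attribute lists the stored array names. This assumes that `save_data_sources` stores each dictionary entry under its key (as `np.savez(path, **data_sources)` would); the helper `load_data_sources` is hypothetical:

```python
import numpy as np


def load_data_sources(path):
    """Read every array from a .npz archive into a plain dict keyed by array name."""
    with np.load(path) as archive:  # NpzFile closes the underlying file on exit
        return {name: archive[name] for name in archive.files}


# Usage (assuming the keys "trained" and "random" were saved as above):
# data_sources = load_data_sources("data_sources.npz")
# print({name: arr.shape for name, arr in data_sources.items()})
```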
