From 5ddb3c3a969ce1009d72454dabb3b6dde7a89093 Mon Sep 17 00:00:00 2001 From: B Nova Date: Sun, 27 Jul 2025 15:22:36 -0400 Subject: [PATCH 1/2] Compile optimizer after crossover --- genome.py | 17 +++++++++++++++++ reproduction.py | 2 ++ 2 files changed, 19 insertions(+) diff --git a/genome.py b/genome.py index d3f905b..8b631e8 100644 --- a/genome.py +++ b/genome.py @@ -4,6 +4,8 @@ from random import choice, random, shuffle from typing import Dict, List, Tuple +import torch + from neat.aggregations import AggregationFunctionSet from neat.config import ConfigParameter, write_pretty_params from neat.graphs import creates_cycle, required_for_output @@ -191,6 +193,8 @@ def configure_crossover(self, genome1, genome2, config): # Homologous gene: combine genes from both parents. self.nodes[key] = ng1.crossover(ng2) + + def mutate(self, config): """Mutates this genome.""" @@ -409,6 +413,19 @@ def get_pruned_copy(self, genome_config): new_genome.connections = used_connection_genes return new_genome + def compile_optimizer(self, genome_config): + """Compile this genome into a TorchScript optimizer.""" + from graph_builder import DynamicOptimizerModule + + if not self.connections: + self.optimizer = None + else: + module = DynamicOptimizerModule( + self, genome_config.input_keys, genome_config.output_keys, self.graph_dict + ) + self.optimizer = torch.jit.script(module) + self.optimizer_path = None + def add_node(self, node_type: str, activation, aggregation) -> NodeGene: if activation is None and aggregation is None: print("WARNING: node added without any operation") diff --git a/reproduction.py b/reproduction.py index 539ee3b..3a692b2 100644 --- a/reproduction.py +++ b/reproduction.py @@ -103,6 +103,8 @@ def reproduce(self, config, species, pop_size, generation, task): child = config.genome_type(cid) child.configure_crossover(p1, p2, config.genome_config) child.mutate(config.genome_config) + if hasattr(child, "compile_optimizer"): + child.compile_optimizer(config.genome_config) new_population[cid] = child self.ancestors[cid] = (p1_id, p2_id) From ccacf871f7e7f51220751e9b9cd4fa188ea53d0f Mon Sep 17 00:00:00 2001 From: B Nova Date: Sun, 27 Jul 2025 15:34:14 -0400 Subject: [PATCH 2/2] Fix compile_optimizer to rebuild graph --- genome.py | 43 +++++++++++++++++++++++++++++++++---------- 1 file changed, 33 insertions(+), 10 deletions(-) diff --git a/genome.py b/genome.py index 8b631e8..56e5c3c 100644 --- a/genome.py +++ b/genome.py @@ -13,7 +13,7 @@ from attributes import BoolAttribute, FloatAttribute, IntAttribute, StringAttribute from computation_graphs.functions.activation import * from computation_graphs.functions.aggregation import * -from genes import ConnectionGene, NodeGene +from genes import ConnectionGene, NodeGene, NODE_TYPE_TO_INDEX class OptimizerGenomeConfig(object): @@ -415,15 +415,38 @@ def get_pruned_copy(self, genome_config): def compile_optimizer(self, genome_config): """Compile this genome into a TorchScript optimizer.""" - from graph_builder import DynamicOptimizerModule - - if not self.connections: - self.optimizer = None - else: - module = DynamicOptimizerModule( - self, genome_config.input_keys, genome_config.output_keys, self.graph_dict - ) - self.optimizer = torch.jit.script(module) + from graph_builder import rebuild_and_script + + if self.graph_dict is None: + node_ids = sorted(self.nodes.keys()) + node_types = [] + node_attributes = [] + for nid in node_ids: + node = self.nodes[nid] + idx = NODE_TYPE_TO_INDEX.get(node.node_type) + if idx is None: + raise KeyError(f"Unknown node_type {node.node_type!r}") + node_types.append(idx) + node_attributes.append(node.dynamic_attributes) + node_types = torch.tensor(node_types, dtype=torch.long) + + edges = [] + for (src, dst), conn in self.connections.items(): + if conn.enabled and src in node_ids and dst in node_ids: + local_src = node_ids.index(src) + local_dst = node_ids.index(dst) + edges.append([local_src, local_dst]) + if edges: + edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous() + else: + edge_index = torch.empty((2, 0), dtype=torch.long) + self.graph_dict = { + "node_types": node_types, + "edge_index": edge_index, + "node_attributes": node_attributes, + } + + self.optimizer = rebuild_and_script(self.graph_dict, genome_config, key=self.key) self.optimizer_path = None def add_node(self, node_type: str, activation, aggregation) -> NodeGene: