diff --git a/genome.py b/genome.py index d3f905b..56e5c3c 100644 --- a/genome.py +++ b/genome.py @@ -4,6 +4,8 @@ from random import choice, random, shuffle from typing import Dict, List, Tuple +import torch + from neat.aggregations import AggregationFunctionSet from neat.config import ConfigParameter, write_pretty_params from neat.graphs import creates_cycle, required_for_output @@ -11,7 +13,7 @@ from attributes import BoolAttribute, FloatAttribute, IntAttribute, StringAttribute from computation_graphs.functions.activation import * from computation_graphs.functions.aggregation import * -from genes import ConnectionGene, NodeGene +from genes import ConnectionGene, NodeGene, NODE_TYPE_TO_INDEX class OptimizerGenomeConfig(object): @@ -191,6 +193,8 @@ def configure_crossover(self, genome1, genome2, config): # Homologous gene: combine genes from both parents. self.nodes[key] = ng1.crossover(ng2) + + def mutate(self, config): """Mutates this genome.""" @@ -409,6 +413,42 @@ def get_pruned_copy(self, genome_config): new_genome.connections = used_connection_genes return new_genome + def compile_optimizer(self, genome_config): + """Compile this genome into a TorchScript optimizer.""" + from graph_builder import rebuild_and_script + + if self.graph_dict is None: + node_ids = sorted(self.nodes.keys()) + node_types = [] + node_attributes = [] + for nid in node_ids: + node = self.nodes[nid] + idx = NODE_TYPE_TO_INDEX.get(node.node_type) + if idx is None: + raise KeyError(f"Unknown node_type {node.node_type!r}") + node_types.append(idx) + node_attributes.append(node.dynamic_attributes) + node_types = torch.tensor(node_types, dtype=torch.long) + + edges = [] + for (src, dst), conn in self.connections.items(): + if conn.enabled and src in node_ids and dst in node_ids: + local_src = node_ids.index(src) + local_dst = node_ids.index(dst) + edges.append([local_src, local_dst]) + if edges: + edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous() + else: + edge_index = torch.empty((2, 0), dtype=torch.long) + self.graph_dict = { + "node_types": node_types, + "edge_index": edge_index, + "node_attributes": node_attributes, + } + + self.optimizer = rebuild_and_script(self.graph_dict, genome_config, key=self.key) + self.optimizer_path = None + def add_node(self, node_type: str, activation, aggregation) -> NodeGene: if activation is None and aggregation is None: print("WARNING: node added without any operation") diff --git a/reproduction.py b/reproduction.py index 539ee3b..3a692b2 100644 --- a/reproduction.py +++ b/reproduction.py @@ -103,6 +103,8 @@ def reproduce(self, config, species, pop_size, generation, task): child = config.genome_type(cid) child.configure_crossover(p1, p2, config.genome_config) child.mutate(config.genome_config) + if hasattr(child, "compile_optimizer"): + child.compile_optimizer(config.genome_config) new_population[cid] = child self.ancestors[cid] = (p1_id, p2_id)