Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion genome.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@
from random import choice, random, shuffle
from typing import Dict, List, Tuple

import torch

from neat.aggregations import AggregationFunctionSet
from neat.config import ConfigParameter, write_pretty_params
from neat.graphs import creates_cycle, required_for_output

from attributes import BoolAttribute, FloatAttribute, IntAttribute, StringAttribute
from computation_graphs.functions.activation import *
from computation_graphs.functions.aggregation import *
from genes import ConnectionGene, NodeGene
from genes import ConnectionGene, NodeGene, NODE_TYPE_TO_INDEX


class OptimizerGenomeConfig(object):
Expand Down Expand Up @@ -191,6 +193,8 @@ def configure_crossover(self, genome1, genome2, config):
# Homologous gene: combine genes from both parents.
self.nodes[key] = ng1.crossover(ng2)



def mutate(self, config):
"""Mutates this genome."""

Expand Down Expand Up @@ -409,6 +413,42 @@ def get_pruned_copy(self, genome_config):
new_genome.connections = used_connection_genes
return new_genome

def compile_optimizer(self, genome_config):
"""Compile this genome into a TorchScript optimizer."""
from graph_builder import rebuild_and_script

if self.graph_dict is None:
node_ids = sorted(self.nodes.keys())
node_types = []
node_attributes = []
for nid in node_ids:
node = self.nodes[nid]
idx = NODE_TYPE_TO_INDEX.get(node.node_type)
if idx is None:
raise KeyError(f"Unknown node_type {node.node_type!r}")
node_types.append(idx)
node_attributes.append(node.dynamic_attributes)
node_types = torch.tensor(node_types, dtype=torch.long)

edges = []
for (src, dst), conn in self.connections.items():
if conn.enabled and src in node_ids and dst in node_ids:
local_src = node_ids.index(src)
local_dst = node_ids.index(dst)
edges.append([local_src, local_dst])
if edges:
edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
else:
edge_index = torch.empty((2, 0), dtype=torch.long)
self.graph_dict = {
"node_types": node_types,
"edge_index": edge_index,
"node_attributes": node_attributes,
}

self.optimizer = rebuild_and_script(self.graph_dict, genome_config, key=self.key)
self.optimizer_path = None

def add_node(self, node_type: str, activation, aggregation) -> NodeGene:
if activation is None and aggregation is None:
print("WARNING: node added without any operation")
Expand Down
2 changes: 2 additions & 0 deletions reproduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ def reproduce(self, config, species, pop_size, generation, task):
child = config.genome_type(cid)
child.configure_crossover(p1, p2, config.genome_config)
child.mutate(config.genome_config)
if hasattr(child, "compile_optimizer"):
child.compile_optimizer(config.genome_config)
new_population[cid] = child
self.ancestors[cid] = (p1_id, p2_id)

Expand Down