Skip to content

Commit

Permalink
Refactor DefaultGenome.get_pruned_copy to use graphs.required_for_out…
Browse files Browse the repository at this point in the history
…put.

Switch to using f-strings.
Check that required_for_output input and output sets are disjoint.
Fix typo.
  • Loading branch information
CodeReclaimers committed May 1, 2022
1 parent 39d6719 commit 93fe32d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 29 deletions.
40 changes: 12 additions & 28 deletions neat/genome.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from neat.config import ConfigParameter, write_pretty_params
from neat.genes import DefaultConnectionGene, DefaultNodeGene
from neat.graphs import creates_cycle
from neat.graphs import required_for_output


class DefaultGenomeConfig(object):
Expand Down Expand Up @@ -76,8 +77,7 @@ def __init__(self, params):
elif self.structural_mutation_surer.lower() == 'default':
self.structural_mutation_surer = 'default'
else:
error_string = "Invalid structural_mutation_surer {!r}".format(
self.structural_mutation_surer)
error_string = f"Invalid structural_mutation_surer {self.structural_mutation_surer!r}"
raise RuntimeError(error_string)

self.node_indexer = None
Expand All @@ -93,10 +93,9 @@ def save(self, f):
if not (0 <= self.connection_fraction <= 1):
raise RuntimeError(
"'partial' connection value must be between 0.0 and 1.0, inclusive.")
f.write('initial_connection = {0} {1}\n'.format(self.initial_connection,
self.connection_fraction))
f.write(f'initial_connection = {self.initial_connection} {self.connection_fraction}\n')
else:
f.write('initial_connection = {0}\n'.format(self.initial_connection))
f.write(f'initial_connection = {self.initial_connection}\n')

assert self.initial_connection in self.allowed_connectivity

Expand Down Expand Up @@ -124,8 +123,7 @@ def check_structural_mutation_surer(self):
elif self.structural_mutation_surer == 'default':
return self.single_structural_mutation
else:
error_string = "Invalid structural_mutation_surer {!r}".format(
self.structural_mutation_surer)
error_string = f"Invalid structural_mutation_surer {self.structural_mutation_surer!r}"
raise RuntimeError(error_string)


Expand Down Expand Up @@ -224,10 +222,8 @@ def configure_new(self, config):
if config.num_hidden > 0:
print(
"Warning: initial_connection = partial with hidden nodes will not do direct input-output connections;",
"\tif this is desired, set initial_connection = partial_nodirect {0};".format(
config.connection_fraction),
"\tif not, set initial_connection = partial_direct {0}".format(
config.connection_fraction),
f"\tif this is desired, set initial_connection = partial_nodirect {config.connection_fraction};",
f"\tif not, set initial_connection = partial_direct {config.connection_fraction}",
sep='\n', file=sys.stderr)
self.connect_partial_nodirect(config)

Expand Down Expand Up @@ -452,9 +448,9 @@ def size(self):
return len(self.nodes), num_enabled_connections

def __str__(self):
s = "Key: {0}\nFitness: {1}\nNodes:".format(self.key, self.fitness)
s = f"Key: {self.key}\nFitness: {self.fitness}\nNodes:"
for k, ng in self.nodes.items():
s += "\n\t{0} {1!s}".format(k, ng)
s += f"\n\t{k} {ng!s}"
s += "\nConnections:"
connections = list(self.connections.values())
connections.sort()
Expand Down Expand Up @@ -567,20 +563,8 @@ def connect_partial_direct(self, config):
self.connections[connection.key] = connection

def get_pruned_copy(self, genome_config):
# Determine which nodes are connected via enabled connections to any output node.
used_nodes = set(genome_config.output_keys)
pending = set(genome_config.output_keys)
while pending:
new_pending = set()
for key, cg in self.connections.items():
if not cg.enabled:
continue

in_node_id, out_node_id = key
if out_node_id in pending and in_node_id not in used_nodes:
new_pending.add(in_node_id)
used_nodes.add(in_node_id)
pending = new_pending
used_nodes = required_for_output(genome_config.input_keys, genome_config.output_keys, self.connections)
used_pins = used_nodes.union(genome_config.input_keys)

# Copy used nodes into a new genome.
new_genome = DefaultGenome(None)
Expand All @@ -591,7 +575,7 @@ def get_pruned_copy(self, genome_config):
# Copy enabled and used connections into the new genome.
for key, cg in self.connections.items():
in_node_id, out_node_id = key
if cg.enabled and in_node_id in used_nodes and out_node_id in used_nodes:
if cg.enabled and in_node_id in used_pins and out_node_id in used_pins:
new_genome.connections[key] = copy.deepcopy(cg)

return new_genome
3 changes: 2 additions & 1 deletion neat/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,12 @@ def required_for_output(inputs, outputs, connections):
Returns a set of identifiers of required nodes.
"""
assert not set(inputs).intersection(outputs)

required = set(outputs)
s = set(outputs)
while 1:
# Find nodes not in S whose output is consumed by a node in s.
# Find nodes not in s whose output is consumed by a node in s.
t = set(a for (a, b) in connections if b in s and a not in s)

if not t:
Expand Down

0 comments on commit 93fe32d

Please sign in to comment.