Skip to content

Commit

Permalink
Added IntegerAttribute.
Browse files Browse the repository at this point in the history
Implement attribute validation mechanism.
Use f-strings instead of str.format.
Check that string attribute default value is in the list of options.
Remove deprecated ConfigParser API usage.
Factor genome pruning mechanism into a module-level function to allow use with other genome types.
  • Loading branch information
CodeReclaimers committed May 1, 2022
1 parent 85e79ab commit 2371fcd
Show file tree
Hide file tree
Showing 6 changed files with 198 additions and 79 deletions.
73 changes: 58 additions & 15 deletions neat/attributes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Deals with the attributes (variable parameters) of genes"""
from random import choice, gauss, random, uniform
from random import choice, gauss, random, uniform, randint

from neat.config import ConfigParameter

Expand All @@ -18,13 +18,11 @@ def __init__(self, name, **default_dict):
setattr(self, n + "_name", self.config_item_name(n))

def config_item_name(self, config_item_base_name):
return "{0}_{1}".format(self.name, config_item_base_name)
return f"{self.name}_{config_item_base_name}"

def get_config_params(self):
return [ConfigParameter(self.config_item_name(n),
self._config_items[n][0],
self._config_items[n][1])
for n in self._config_items]
return [ConfigParameter(self.config_item_name(n), ci[0], ci[1])
for n, ci in self._config_items.items()]


class FloatAttribute(BaseAttribute):
Expand Down Expand Up @@ -61,9 +59,7 @@ def init_value(self, config):
(mean + (2 * stdev)))
return uniform(min_value, max_value)

raise RuntimeError("Unknown init_type {!r} for {!s}".format(getattr(config,
self.init_type_name),
self.init_type_name))
raise RuntimeError(f"Unknown init_type {getattr(config, self.init_type_name)!r} for {self.init_type_name!s}")

def mutate_value(self, value, config):
# mutate_rate is usually no lower than replace_rate, and frequently higher -
Expand All @@ -82,7 +78,49 @@ def mutate_value(self, value, config):

return value

def validate(self, config): # pragma: no cover
def validate(self, config):
pass


class IntegerAttribute(BaseAttribute):
"""
Class for numeric attributes,
such as the response of a node or the weight of a connection.
"""
_config_items = {"replace_rate": [float, None],
"mutate_rate": [float, None],
"mutate_power": [float, None],
"max_value": [float, None],
"min_value": [float, None]}

def clamp(self, value, config):
min_value = getattr(config, self.min_value_name)
max_value = getattr(config, self.max_value_name)
return max(min(value, max_value), min_value)

def init_value(self, config):
min_value = getattr(config, self.min_value_name)
max_value = getattr(config, self.max_value_name)
return randint(min_value, max_value)

def mutate_value(self, value, config):
# mutate_rate is usually no lower than replace_rate, and frequently higher -
# so put first for efficiency
mutate_rate = getattr(config, self.mutate_rate_name)

r = random()
if r < mutate_rate:
mutate_power = getattr(config, self.mutate_power_name)
return self.clamp(value + int(round(gauss(0.0, mutate_power))), config)

replace_rate = getattr(config, self.replace_rate_name)

if r < replace_rate + mutate_rate:
return self.init_value(config)

return value

def validate(self, config):
pass


Expand All @@ -103,8 +141,7 @@ def init_value(self, config):
elif default in ('random', 'none'):
return bool(random() < 0.5)

raise RuntimeError("Unknown default value {!r} for {!s}".format(default,
self.name))
raise RuntimeError(f"Unknown default value {default!r} for {self.name!s}")

def mutate_value(self, value, config):
mutate_rate = getattr(config, self.mutate_rate_name)
Expand All @@ -125,7 +162,7 @@ def mutate_value(self, value, config):

return value

def validate(self, config): # pragma: no cover
def validate(self, config):
pass


Expand Down Expand Up @@ -158,5 +195,11 @@ def mutate_value(self, value, config):

return value

def validate(self, config): # pragma: no cover
pass
def validate(self, config):
default = getattr(config, self.default_name)
if default not in ('none', 'random'):
options = getattr(config, self.options_name)
if default not in options:
raise RuntimeError(f'Invalid activation function name: {default}')
assert default in options

44 changes: 17 additions & 27 deletions neat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,8 @@ def __init__(self, name, value_type, default=None):

def __repr__(self):
if self.default is None:
return "ConfigParameter({!r}, {!r})".format(self.name,
self.value_type)
return "ConfigParameter({!r}, {!r}, {!r})".format(self.name,
self.value_type,
self.default)
return f"ConfigParameter({self.name!r}, {self.value_type!r})"
return f"ConfigParameter({self.name!r}, {self.value_type!r}, {self.default!r})"

def parse(self, section, config_parser):
if int == self.value_type:
Expand All @@ -34,8 +31,7 @@ def parse(self, section, config_parser):
if str == self.value_type:
return config_parser.get(section, self.name)

raise RuntimeError("Unexpected configuration type: "
+ repr(self.value_type))
raise RuntimeError(f"Unexpected configuration type: {self.value_type!r}")

def interpret(self, config_dict):
"""
Expand All @@ -47,8 +43,7 @@ def interpret(self, config_dict):
if self.default is None:
raise RuntimeError('Missing configuration item: ' + self.name)
else:
warnings.warn("Using default {!r} for '{!s}'".format(self.default, self.name),
DeprecationWarning)
warnings.warn(f"Using default {self.default!r} for '{self.name!s}'", DeprecationWarning)
if (str != self.value_type) and isinstance(self.default, self.value_type):
return self.default
else:
Expand All @@ -71,8 +66,8 @@ def interpret(self, config_dict):
if list == self.value_type:
return value.split(" ")
except Exception:
raise RuntimeError("Error interpreting config item '{}' with value {!r} and type {}".format(
self.name, value, self.value_type))
raise RuntimeError(
f"Error interpreting config item '{self.name}' with value {value!r} and type {self.value_type}")

raise RuntimeError("Unexpected configuration type: " + repr(self.value_type))

Expand All @@ -90,7 +85,7 @@ def write_pretty_params(f, config, params):

for name in param_names:
p = params[name]
f.write('{} = {}\n'.format(p.name.ljust(longest_name), p.format(getattr(config, p.name))))
f.write(f'{p.name.ljust(longest_name)} = {p.format(getattr(config, p.name))}\n')


class UnknownConfigItemError(NameError):
Expand All @@ -115,7 +110,7 @@ def __init__(self, param_dict, param_list):
if len(unknown_list) > 1:
raise UnknownConfigItemError("Unknown configuration items:\n" +
"\n\t".join(unknown_list))
raise UnknownConfigItemError("Unknown configuration item {!s}".format(unknown_list[0]))
raise UnknownConfigItemError(f"Unknown configuration item {unknown_list[0]!s}")

@classmethod
def write_config(cls, f, config):
Expand All @@ -124,7 +119,7 @@ def write_config(cls, f, config):


class Config(object):
"""A simple container for user-configurable parameters of NEAT."""
"""A container for user-configurable parameters of NEAT."""

__params = [ConfigParameter('pop_size', int),
ConfigParameter('fitness_criterion', str),
Expand All @@ -149,10 +144,7 @@ def __init__(self, genome_type, reproduction_type, species_set_type, stagnation_

parameters = ConfigParser()
with open(filename) as f:
if hasattr(parameters, 'read_file'):
parameters.read_file(f)
else:
parameters.readfp(f)
parameters.read_file(f)

# NEAT configuration
if not parameters.has_section('NEAT'):
Expand All @@ -167,17 +159,15 @@ def __init__(self, genome_type, reproduction_type, species_set_type, stagnation_
setattr(self, p.name, p.parse('NEAT', parameters))
except Exception:
setattr(self, p.name, p.default)
warnings.warn("Using default {!r} for '{!s}'".format(p.default, p.name),
warnings.warn(f"Using default {p.default!r} for '{p.name!s}'",
DeprecationWarning)
param_list_names.append(p.name)
param_dict = dict(parameters.items('NEAT'))
unknown_list = [x for x in param_dict if x not in param_list_names]
if unknown_list:
if len(unknown_list) > 1:
raise UnknownConfigItemError("Unknown (section 'NEAT') configuration items:\n" +
"\n\t".join(unknown_list))
raise UnknownConfigItemError(
"Unknown (section 'NEAT') configuration item {!s}".format(unknown_list[0]))
raise UnknownConfigItemError("Unknown (section 'NEAT') configuration items:\n" + "\n\t".join(unknown_list))
raise UnknownConfigItemError(f"Unknown (section 'NEAT') configuration item {unknown_list[0]!s}")

# Parse type sections.
genome_dict = dict(parameters.items(genome_type.__name__))
Expand All @@ -199,14 +189,14 @@ def save(self, filename):
f.write('[NEAT]\n')
write_pretty_params(f, self, self.__params)

f.write('\n[{0}]\n'.format(self.genome_type.__name__))
f.write(f'\n[{self.genome_type.__name__}]\n')
self.genome_type.write_config(f, self.genome_config)

f.write('\n[{0}]\n'.format(self.species_set_type.__name__))
f.write(f'\n[{self.species_set_type.__name__}]\n')
self.species_set_type.write_config(f, self.species_set_config)

f.write('\n[{0}]\n'.format(self.stagnation_type.__name__))
f.write(f'\n[{self.stagnation_type.__name__}]\n')
self.stagnation_type.write_config(f, self.stagnation_config)

f.write('\n[{0}]\n'.format(self.reproduction_type.__name__))
f.write(f'\n[{self.reproduction_type.__name__}]\n')
self.reproduction_type.write_config(f, self.reproduction_config)
22 changes: 13 additions & 9 deletions neat/genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ def __init__(self, key):

def __str__(self):
attrib = ['key'] + [a.name for a in self._gene_attributes]
attrib = ['{0}={1}'.format(a, getattr(self, a)) for a in attrib]
return '{0}({1})'.format(self.__class__.__name__, ", ".join(attrib))
attrib = [f'{a}={getattr(self, a)}' for a in attrib]
return f'{self.__class__.__name__}({", ".join(attrib)})'

def __lt__(self, other):
assert isinstance(self.key, type(other.key)), "Cannot compare keys {0!r} and {1!r}".format(self.key, other.key)
assert isinstance(self.key, type(other.key)), f"Cannot compare keys {self.key!r} and {other.key!r}"
return self.key < other.key

@classmethod
Expand All @@ -37,13 +37,17 @@ def get_config_params(cls):
if not hasattr(cls, '_gene_attributes'):
setattr(cls, '_gene_attributes', getattr(cls, '__gene_attributes__'))
warnings.warn(
"Class '{!s}' {!r} needs '_gene_attributes' not '__gene_attributes__'".format(
cls.__name__, cls),
f"Class '{cls.__name__!s}' {cls!r} needs '_gene_attributes' not '__gene_attributes__'",
DeprecationWarning)
for a in cls._gene_attributes:
params += a.get_config_params()
return params

@classmethod
def validate_attributes(cls, config):
for a in cls._gene_attributes:
a.validate(config)

def init_attributes(self, config):
for a in self._gene_attributes:
setattr(self, a.name, a.init_value(config))
Expand Down Expand Up @@ -82,11 +86,11 @@ def crossover(self, gene2):
class DefaultNodeGene(BaseGene):
_gene_attributes = [FloatAttribute('bias'),
FloatAttribute('response'),
StringAttribute('activation', options='sigmoid'),
StringAttribute('aggregation', options='sum')]
StringAttribute('activation', options=''),
StringAttribute('aggregation', options='')]

def __init__(self, key):
assert isinstance(key, int), "DefaultNodeGene key must be an int, not {!r}".format(key)
assert isinstance(key, int), f"DefaultNodeGene key must be an int, not {key!r}"
BaseGene.__init__(self, key)

def distance(self, other, config):
Expand All @@ -109,7 +113,7 @@ class DefaultConnectionGene(BaseGene):
BoolAttribute('enabled')]

def __init__(self, key):
assert isinstance(key, tuple), "DefaultConnectionGene key must be a tuple, not {!r}".format(key)
assert isinstance(key, tuple), f"DefaultConnectionGene key must be a tuple, not {key!r}"
BaseGene.__init__(self, key)

def distance(self, other, config):
Expand Down
41 changes: 27 additions & 14 deletions neat/genome.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ def __init__(self, params):
for p in self._params:
setattr(self, p.name, p.interpret(params))

self.node_gene_type.validate_attributes(self)
self.connection_gene_type.validate_attributes(self)

# By convention, input pins have negative keys, and the output
# pins have keys 0,1,...
self.input_keys = [-i - 1 for i in range(self.num_inputs)]
Expand Down Expand Up @@ -540,7 +543,8 @@ def connect_full_direct(self, config):
def connect_partial_nodirect(self, config):
"""
Create a partially-connected genome,
with (unless no hidden nodes) no direct input-output connections."""
with (unless no hidden nodes) no direct input-output connections.
"""
assert 0 <= config.connection_fraction <= 1
all_connections = self.compute_full_connections(config, False)
shuffle(all_connections)
Expand All @@ -563,19 +567,28 @@ def connect_partial_direct(self, config):
self.connections[connection.key] = connection

def get_pruned_copy(self, genome_config):
used_nodes = required_for_output(genome_config.input_keys, genome_config.output_keys, self.connections)
used_pins = used_nodes.union(genome_config.input_keys)

# Copy used nodes into a new genome.
used_node_genes, used_connection_genes = get_pruned_genes(self.nodes, self.connections,
genome_config.input_keys, genome_config.output_keys)
new_genome = DefaultGenome(None)
for n in used_nodes:
if n in self.nodes:
new_genome.nodes[n] = copy.deepcopy(self.nodes[n])
new_genome.nodes = used_node_genes
new_genome.connections = used_connection_genes
return new_genome

# Copy enabled and used connections into the new genome.
for key, cg in self.connections.items():
in_node_id, out_node_id = key
if cg.enabled and in_node_id in used_pins and out_node_id in used_pins:
new_genome.connections[key] = copy.deepcopy(cg)

return new_genome
def get_pruned_genes(node_genes, connection_genes, input_keys, output_keys):
used_nodes = required_for_output(input_keys, output_keys, connection_genes)
used_pins = used_nodes.union(input_keys)

# Copy used nodes into a new genome.
used_node_genes = {}
for n in used_nodes:
used_node_genes[n] = copy.deepcopy(node_genes[n])

# Copy enabled and used connections into the new genome.
used_connection_genes = {}
for key, cg in connection_genes.items():
in_node_id, out_node_id = key
if cg.enabled and in_node_id in used_pins and out_node_id in used_pins:
used_connection_genes[key] = copy.deepcopy(cg)

return used_node_genes, used_connection_genes
Loading

0 comments on commit 2371fcd

Please sign in to comment.