From 2371fcd2e334d8f603423b5e0774d63b7c12de9e Mon Sep 17 00:00:00 2001 From: CodeReclaimers Date: Sun, 1 May 2022 14:55:18 -0400 Subject: [PATCH] Added IntegerAttribute. Implement attribute validation mechanism. Use f-strings instead of str.format. Check that string attribute default value is in the list of options. Remove deprecated ConfigParser API usage. Factor genome pruning mechanism into a module-level function to allow use with other genome types. --- neat/attributes.py | 73 +++++++++++++++++++++++++++++-------- neat/config.py | 44 +++++++++------------- neat/genes.py | 22 ++++++----- neat/genome.py | 41 ++++++++++++++------- tests/bad_configuration0 | 79 ++++++++++++++++++++++++++++++++++++++++ tests/test_config.py | 18 ++------- 6 files changed, 198 insertions(+), 79 deletions(-) create mode 100644 tests/bad_configuration0 diff --git a/neat/attributes.py b/neat/attributes.py index 5f71fbbf..0853bbb7 100644 --- a/neat/attributes.py +++ b/neat/attributes.py @@ -1,5 +1,5 @@ """Deals with the attributes (variable parameters) of genes""" -from random import choice, gauss, random, uniform +from random import choice, gauss, random, uniform, randint from neat.config import ConfigParameter @@ -18,13 +18,11 @@ def __init__(self, name, **default_dict): setattr(self, n + "_name", self.config_item_name(n)) def config_item_name(self, config_item_base_name): - return "{0}_{1}".format(self.name, config_item_base_name) + return f"{self.name}_{config_item_base_name}" def get_config_params(self): - return [ConfigParameter(self.config_item_name(n), - self._config_items[n][0], - self._config_items[n][1]) - for n in self._config_items] + return [ConfigParameter(self.config_item_name(n), ci[0], ci[1]) + for n, ci in self._config_items.items()] class FloatAttribute(BaseAttribute): @@ -61,9 +59,7 @@ def init_value(self, config): (mean + (2 * stdev))) return uniform(min_value, max_value) - raise RuntimeError("Unknown init_type {!r} for {!s}".format(getattr(config, - self.init_type_name), - self.init_type_name)) + raise RuntimeError(f"Unknown init_type {getattr(config, self.init_type_name)!r} for {self.init_type_name!s}") def mutate_value(self, value, config): # mutate_rate is usually no lower than replace_rate, and frequently higher - @@ -82,7 +78,49 @@ def mutate_value(self, value, config): return value - def validate(self, config): # pragma: no cover + def validate(self, config): + pass + + +class IntegerAttribute(BaseAttribute): + """ + Class for numeric attributes, + such as the response of a node or the weight of a connection. + """ + _config_items = {"replace_rate": [float, None], + "mutate_rate": [float, None], + "mutate_power": [float, None], + "max_value": [float, None], + "min_value": [float, None]} + + def clamp(self, value, config): + min_value = getattr(config, self.min_value_name) + max_value = getattr(config, self.max_value_name) + return max(min(value, max_value), min_value) + + def init_value(self, config): + min_value = getattr(config, self.min_value_name) + max_value = getattr(config, self.max_value_name) + return randint(min_value, max_value) + + def mutate_value(self, value, config): + # mutate_rate is usually no lower than replace_rate, and frequently higher - + # so put first for efficiency + mutate_rate = getattr(config, self.mutate_rate_name) + + r = random() + if r < mutate_rate: + mutate_power = getattr(config, self.mutate_power_name) + return self.clamp(value + int(round(gauss(0.0, mutate_power))), config) + + replace_rate = getattr(config, self.replace_rate_name) + + if r < replace_rate + mutate_rate: + return self.init_value(config) + + return value + + def validate(self, config): pass @@ -103,8 +141,7 @@ def init_value(self, config): elif default in ('random', 'none'): return bool(random() < 0.5) - raise RuntimeError("Unknown default value {!r} for {!s}".format(default, - self.name)) + raise RuntimeError(f"Unknown default value {default!r} for {self.name!s}") def mutate_value(self, value, config): mutate_rate = getattr(config, self.mutate_rate_name) @@ -125,7 +162,7 @@ def mutate_value(self, value, config): return value - def validate(self, config): # pragma: no cover + def validate(self, config): pass @@ -158,5 +195,11 @@ def mutate_value(self, value, config): return value - def validate(self, config): # pragma: no cover - pass + def validate(self, config): + default = getattr(config, self.default_name) + if default not in ('none', 'random'): + options = getattr(config, self.options_name) + if default not in options: + raise RuntimeError(f'Invalid activation function name: {default}') + assert default in options + diff --git a/neat/config.py b/neat/config.py index 58a74dc9..7c17ad6f 100644 --- a/neat/config.py +++ b/neat/config.py @@ -15,11 +15,8 @@ def __init__(self, name, value_type, default=None): def __repr__(self): if self.default is None: - return "ConfigParameter({!r}, {!r})".format(self.name, - self.value_type) - return "ConfigParameter({!r}, {!r}, {!r})".format(self.name, - self.value_type, - self.default) + return f"ConfigParameter({self.name!r}, {self.value_type!r})" + return f"ConfigParameter({self.name!r}, {self.value_type!r}, {self.default!r})" def parse(self, section, config_parser): if int == self.value_type: @@ -34,8 +31,7 @@ def parse(self, section, config_parser): if str == self.value_type: return config_parser.get(section, self.name) - raise RuntimeError("Unexpected configuration type: " - + repr(self.value_type)) + raise RuntimeError(f"Unexpected configuration type: {self.value_type!r}") def interpret(self, config_dict): """ @@ -47,8 +43,7 @@ def interpret(self, config_dict): if self.default is None: raise RuntimeError('Missing configuration item: ' + self.name) else: - warnings.warn("Using default {!r} for '{!s}'".format(self.default, self.name), - DeprecationWarning) + warnings.warn(f"Using default {self.default!r} for '{self.name!s}'", DeprecationWarning) if (str != self.value_type) and isinstance(self.default, self.value_type): return self.default else: @@ -71,8 +66,8 @@ def interpret(self, config_dict): if list == self.value_type: return value.split(" ") except Exception: - raise RuntimeError("Error interpreting config item '{}' with value {!r} and type {}".format( - self.name, value, self.value_type)) + raise RuntimeError( + f"Error interpreting config item '{self.name}' with value {value!r} and type {self.value_type}") raise RuntimeError("Unexpected configuration type: " + repr(self.value_type)) @@ -90,7 +85,7 @@ def write_pretty_params(f, config, params): for name in param_names: p = params[name] - f.write('{} = {}\n'.format(p.name.ljust(longest_name), p.format(getattr(config, p.name)))) + f.write(f'{p.name.ljust(longest_name)} = {p.format(getattr(config, p.name))}\n') class UnknownConfigItemError(NameError): @@ -115,7 +110,7 @@ def __init__(self, param_dict, param_list): if len(unknown_list) > 1: raise UnknownConfigItemError("Unknown configuration items:\n" + "\n\t".join(unknown_list)) - raise UnknownConfigItemError("Unknown configuration item {!s}".format(unknown_list[0])) + raise UnknownConfigItemError(f"Unknown configuration item {unknown_list[0]!s}") @classmethod def write_config(cls, f, config): @@ -124,7 +119,7 @@ def write_config(cls, f, config): class Config(object): - """A simple container for user-configurable parameters of NEAT.""" + """A container for user-configurable parameters of NEAT.""" __params = [ConfigParameter('pop_size', int), ConfigParameter('fitness_criterion', str), @@ -149,10 +144,7 @@ def __init__(self, genome_type, reproduction_type, species_set_type, stagnation_ parameters = ConfigParser() with open(filename) as f: - if hasattr(parameters, 'read_file'): - parameters.read_file(f) - else: - parameters.readfp(f) + parameters.read_file(f) # NEAT configuration if not parameters.has_section('NEAT'): @@ -167,17 +159,15 @@ def __init__(self, genome_type, reproduction_type, species_set_type, stagnation_ setattr(self, p.name, p.parse('NEAT', parameters)) except Exception: setattr(self, p.name, p.default) - warnings.warn("Using default {!r} for '{!s}'".format(p.default, p.name), + warnings.warn(f"Using default {p.default!r} for '{p.name!s}'", DeprecationWarning) param_list_names.append(p.name) param_dict = dict(parameters.items('NEAT')) unknown_list = [x for x in param_dict if x not in param_list_names] if unknown_list: if len(unknown_list) > 1: - raise UnknownConfigItemError("Unknown (section 'NEAT') configuration items:\n" + - "\n\t".join(unknown_list)) - raise UnknownConfigItemError( - "Unknown (section 'NEAT') configuration item {!s}".format(unknown_list[0])) + raise UnknownConfigItemError("Unknown (section 'NEAT') configuration items:\n" + "\n\t".join(unknown_list)) + raise UnknownConfigItemError(f"Unknown (section 'NEAT') configuration item {unknown_list[0]!s}") # Parse type sections. genome_dict = dict(parameters.items(genome_type.__name__)) @@ -199,14 +189,14 @@ def save(self, filename): f.write('[NEAT]\n') write_pretty_params(f, self, self.__params) - f.write('\n[{0}]\n'.format(self.genome_type.__name__)) + f.write(f'\n[{self.genome_type.__name__}]\n') self.genome_type.write_config(f, self.genome_config) - f.write('\n[{0}]\n'.format(self.species_set_type.__name__)) + f.write(f'\n[{self.species_set_type.__name__}]\n') self.species_set_type.write_config(f, self.species_set_config) - f.write('\n[{0}]\n'.format(self.stagnation_type.__name__)) + f.write(f'\n[{self.stagnation_type.__name__}]\n') self.stagnation_type.write_config(f, self.stagnation_config) - f.write('\n[{0}]\n'.format(self.reproduction_type.__name__)) + f.write(f'\n[{self.reproduction_type.__name__}]\n') self.reproduction_type.write_config(f, self.reproduction_config) diff --git a/neat/genes.py b/neat/genes.py index cbc0ad7e..6b02840b 100644 --- a/neat/genes.py +++ b/neat/genes.py @@ -20,11 +20,11 @@ def __init__(self, key): def __str__(self): attrib = ['key'] + [a.name for a in self._gene_attributes] - attrib = ['{0}={1}'.format(a, getattr(self, a)) for a in attrib] - return '{0}({1})'.format(self.__class__.__name__, ", ".join(attrib)) + attrib = [f'{a}={getattr(self, a)}' for a in attrib] + return f'{self.__class__.__name__}({", ".join(attrib)})' def __lt__(self, other): - assert isinstance(self.key, type(other.key)), "Cannot compare keys {0!r} and {1!r}".format(self.key, other.key) + assert isinstance(self.key, type(other.key)), f"Cannot compare keys {self.key!r} and {other.key!r}" return self.key < other.key @classmethod @@ -37,13 +37,17 @@ def get_config_params(cls): if not hasattr(cls, '_gene_attributes'): setattr(cls, '_gene_attributes', getattr(cls, '__gene_attributes__')) warnings.warn( - "Class '{!s}' {!r} needs '_gene_attributes' not '__gene_attributes__'".format( - cls.__name__, cls), + f"Class '{cls.__name__!s}' {cls!r} needs '_gene_attributes' not '__gene_attributes__'", DeprecationWarning) for a in cls._gene_attributes: params += a.get_config_params() return params + @classmethod + def validate_attributes(cls, config): + for a in cls._gene_attributes: + a.validate(config) + def init_attributes(self, config): for a in self._gene_attributes: setattr(self, a.name, a.init_value(config)) @@ -82,11 +86,11 @@ def crossover(self, gene2): class DefaultNodeGene(BaseGene): _gene_attributes = [FloatAttribute('bias'), FloatAttribute('response'), - StringAttribute('activation', options='sigmoid'), - StringAttribute('aggregation', options='sum')] + StringAttribute('activation', options=''), + StringAttribute('aggregation', options='')] def __init__(self, key): - assert isinstance(key, int), "DefaultNodeGene key must be an int, not {!r}".format(key) + assert isinstance(key, int), f"DefaultNodeGene key must be an int, not {key!r}" BaseGene.__init__(self, key) def distance(self, other, config): @@ -109,7 +113,7 @@ class DefaultConnectionGene(BaseGene): BoolAttribute('enabled')] def __init__(self, key): - assert isinstance(key, tuple), "DefaultConnectionGene key must be a tuple, not {!r}".format(key) + assert isinstance(key, tuple), f"DefaultConnectionGene key must be a tuple, not {key!r}" BaseGene.__init__(self, key) def distance(self, other, config): diff --git a/neat/genome.py b/neat/genome.py index b46ccafc..2d652650 100644 --- a/neat/genome.py +++ b/neat/genome.py @@ -49,6 +49,9 @@ def __init__(self, params): for p in self._params: setattr(self, p.name, p.interpret(params)) + self.node_gene_type.validate_attributes(self) + self.connection_gene_type.validate_attributes(self) + # By convention, input pins have negative keys, and the output # pins have keys 0,1,... self.input_keys = [-i - 1 for i in range(self.num_inputs)] @@ -540,7 +543,8 @@ def connect_full_direct(self, config): def connect_partial_nodirect(self, config): """ Create a partially-connected genome, - with (unless no hidden nodes) no direct input-output connections.""" + with (unless no hidden nodes) no direct input-output connections. + """ assert 0 <= config.connection_fraction <= 1 all_connections = self.compute_full_connections(config, False) shuffle(all_connections) @@ -563,19 +567,28 @@ def connect_partial_direct(self, config): self.connections[connection.key] = connection def get_pruned_copy(self, genome_config): - used_nodes = required_for_output(genome_config.input_keys, genome_config.output_keys, self.connections) - used_pins = used_nodes.union(genome_config.input_keys) - - # Copy used nodes into a new genome. + used_node_genes, used_connection_genes = get_pruned_genes(self.nodes, self.connections, + genome_config.input_keys, genome_config.output_keys) new_genome = DefaultGenome(None) - for n in used_nodes: - if n in self.nodes: - new_genome.nodes[n] = copy.deepcopy(self.nodes[n]) + new_genome.nodes = used_node_genes + new_genome.connections = used_connection_genes + return new_genome - # Copy enabled and used connections into the new genome. - for key, cg in self.connections.items(): - in_node_id, out_node_id = key - if cg.enabled and in_node_id in used_pins and out_node_id in used_pins: - new_genome.connections[key] = copy.deepcopy(cg) - return new_genome +def get_pruned_genes(node_genes, connection_genes, input_keys, output_keys): + used_nodes = required_for_output(input_keys, output_keys, connection_genes) + used_pins = used_nodes.union(input_keys) + + # Copy used nodes into a new genome. + used_node_genes = {} + for n in used_nodes: + used_node_genes[n] = copy.deepcopy(node_genes[n]) + + # Copy enabled and used connections into the new genome. + used_connection_genes = {} + for key, cg in connection_genes.items(): + in_node_id, out_node_id = key + if cg.enabled and in_node_id in used_pins and out_node_id in used_pins: + used_connection_genes[key] = copy.deepcopy(cg) + + return used_node_genes, used_connection_genes diff --git a/tests/bad_configuration0 b/tests/bad_configuration0 new file mode 100644 index 00000000..9f894f96 --- /dev/null +++ b/tests/bad_configuration0 @@ -0,0 +1,79 @@ +[NEAT] +fitness_criterion = max +fitness_threshold = 0.9 +pop_size = 150 +reset_on_extinction = False + +[DefaultGenome] +# node activation options +activation_default = squanchy +activation_mutate_rate = 0.0 +activation_options = sigmoid + +# node aggregation options +aggregation_default = sum +aggregation_mutate_rate = 0.0 +aggregation_options = sum + +# node bias options +bias_init_mean = 0.0 +bias_init_stdev = 1.0 +bias_max_value = 30.0 +bias_min_value = -30.0 +bias_mutate_power = 0.5 +bias_mutate_rate = 0.7 +bias_replace_rate = 0.1 + +# genome compatibility options +compatibility_disjoint_coefficient = 1.0 +compatibility_weight_coefficient = 0.5 + +# connection add/remove rates +conn_add_prob = 0.5 +conn_delete_prob = 0.5 + +# connection enable options +enabled_default = True +enabled_mutate_rate = 0.01 + +feed_forward = True +initial_connection = full + +# node add/remove rates +node_add_prob = 0.2 +node_delete_prob = 0.2 + +# network parameters +num_hidden = 0 +num_inputs = 2 +num_outputs = 1 + +# node response options +response_init_mean = 1.0 +response_init_stdev = 0.0 +response_max_value = 30.0 +response_min_value = -30.0 +response_mutate_power = 0.0 +response_mutate_rate = 0.0 +response_replace_rate = 0.0 + +# connection weight options +weight_init_mean = 0.0 +weight_init_stdev = 1.0 +weight_max_value = 30 +weight_min_value = -30 +weight_mutate_power = 0.5 +weight_mutate_rate = 0.8 +weight_replace_rate = 0.1 + +[DefaultSpeciesSet] +compatibility_threshold = 3.0 + +[DefaultStagnation] +species_fitness_func = max +max_stagnation = 20 + +[DefaultReproduction] +elitism = 2 +survival_threshold = 0.2 + diff --git a/tests/test_config.py b/tests/test_config.py index 1afcea55..e0b39ecf 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -15,20 +15,10 @@ def test_nonexistent_config(): passed = 'No such config file' in str(e) assert passed - -# Re the below - a partial version is in test_simple_run.py -# TODO: fix this test -# def test_bad_config_activation(): -# """Check that an unknown activation function raises an Exception with -# the appropriate message.""" -# passed = False -# try: -# local_dir = os.path.dirname(__file__) -# c = Config(os.path.join(local_dir, 'bad_configuration1')) -# except Exception as e: -# print(repr(e)) -# passed = 'Invalid activation function name' in str(e) -# assert passed +def test_bad_config_default_activation(): + """Check that an activation function default not in the list of options + raises an Exception with the appropriate message.""" + test_bad_config_RuntimeError(config_file='bad_configuration0') def test_bad_config_unknown_option(): """Check that an unknown option (at least in some sections) raises an exception."""