From 22843ab47cf60d5bac1ec186b38faa38d4a40b00 Mon Sep 17 00:00:00 2001 From: CodeReclaimers Date: Tue, 22 Dec 2015 16:23:56 -0500 Subject: [PATCH] Simplified Genome and FFGenome._mutate_add_connection, and fixed a bug which systematically prevented connections from being added to some nodes. Temporarily commented connection adding code in Genome and FFGenome.add_hidden_nodes, as these always made the network fully connected regardless of configuration. Removed the scary "replace the __dict__" code in checkpoint restoration. Population now throws an exception when all species go extinct, as this is probably something the user should handle. Simplified speciation code. Simplified stats logging code, and keep more detailed per-generation fitness statistics. Added tests to exercise more of the genome-handling code. Simplification of XOR example. visualize.draw_net now takes an optional dictionary which is used to provide labels to be used in place of node IDs in the rendered network. Population now tracks the number of fitness function evaluations directly. Fixed bug (elitism setting was being interpreted as a float instead of an int). Removed compatibility threshold adjustment to control the number of species. Removed species age-adjusted fitness scheme as does not appear to be necessary: the speciation and stagnation mechanisms already provide the same benefit (as far as I can tell). Removed now-unused configuration items from examples. General cleanup of code and comments. --- README.md | 2 - docs/config_file.rst | 14 +- .../pole_balancing/single_pole/ctrnn_config | 8 +- .../single_pole/ctrnn_evolve.py | 4 +- examples/pole_balancing/single_pole/nn_config | 8 +- .../pole_balancing/single_pole/nn_evolve.py | 4 +- examples/xor/xor2.py | 36 +- examples/xor/xor2_config | 6 - examples/xor/xor2_parallel.py | 6 +- examples/xor/xor2_spiking.py | 8 +- neat/config.py | 8 +- neat/diversity.py | 12 +- neat/genes.py | 3 - neat/genome.py | 198 +++++------ neat/population.py | 319 +++++++----------- neat/species.py | 42 +-- neat/visualize.py | 36 +- tests/test_configuration | 6 - tests/test_genome.py | 88 ++++- 19 files changed, 365 insertions(+), 443 deletions(-) diff --git a/README.md b/README.md index 9628fca2..863b37c4 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,5 @@ [![Build Status](https://travis-ci.org/CodeReclaimers/neat-python.svg)](https://travis-ci.org/CodeReclaimers/neat-python) - [![Code Issues](https://www.quantifiedcode.com/api/v1/project/2bb1d19f57684f4589cb4700f99dd75e/badge.svg)](https://www.quantifiedcode.com/app/project/2bb1d19f57684f4589cb4700f99dd75e) - [![Coverage Status](https://coveralls.io/repos/CodeReclaimers/neat-python/badge.svg?branch=master&service=github)](https://coveralls.io/github/CodeReclaimers/neat-python?branch=master) ## About ## diff --git a/docs/config_file.rst b/docs/config_file.rst index e8d550b1..11f1dc82 100644 --- a/docs/config_file.rst +++ b/docs/config_file.rst @@ -64,14 +64,12 @@ NEAT settings. * *prob_toggle_link* The probability that the enabled status of a connection will be toggled. * *elitism* - The number of individuals in each species that will be preserved from one generation to the next. + The number of most fit individuals in each species that will be preserved as-is from one generation to the next. [genotype compatibility] section -------------------------------- * *compatibility_threshold* Individuals whose genomic distance is less than this threshold are considered to be in the same species. -* *compatibility_change* - The amount by which *compatibility_threshold* may be adjusted during a generation to maintain target *species_size*. * *excess_coefficient* The coefficient for the excess gene count's contribution to the genomic distance. * *disjoint_coefficient* @@ -81,18 +79,8 @@ NEAT settings. [species] section ----------------- -* *species_size* - The target number of species to maintain. When the number of species is different from *species_size*, *compatibility_threshold* will be adjusted up or down as necessary to attempt to return to *species_size*. * *survival_threshold* The fraction for each species allowed to reproduce on each generation. -* *old_threshold* - The number of generations beyond which species are considered old. -* *youth_threshold* - The number of generations below which species are considered young. -* *old_penalty* - The multiplicative fitness adjustment applied to old species' average fitness. This value is typically on (0.0, 1.0]. -* *youth_boost* - The multiplicative fitness adjustment applied to young species' average fitness. This value is typically on [1.0, 2.0]. * *max_stagnation* Species that have not shown improvement in more than this number of generations will be considered stagnant and removed. diff --git a/examples/pole_balancing/single_pole/ctrnn_config b/examples/pole_balancing/single_pole/ctrnn_config index 7c06c26c..6142826a 100644 --- a/examples/pole_balancing/single_pole/ctrnn_config +++ b/examples/pole_balancing/single_pole/ctrnn_config @@ -26,20 +26,14 @@ prob_mutate_weight = 0.8 prob_replace_weight = 0.1 weight_mutation_power = 1.0 prob_toggle_link = 0.01 -elitism = 1 +elitism = 2 [genotype compatibility] compatibility_threshold = 3.0 -compatibility_change = 0.0 excess_coefficient = 1.0 disjoint_coefficient = 1.0 weight_coefficient = 0.4 [species] -species_size = 10 survival_threshold = 0.2 -old_threshold = 80 -youth_threshold = 10 -old_penalty = 1.0 -youth_boost = 1.0 max_stagnation = 20 diff --git a/examples/pole_balancing/single_pole/ctrnn_evolve.py b/examples/pole_balancing/single_pole/ctrnn_evolve.py index 245eb26e..15b34bf9 100644 --- a/examples/pole_balancing/single_pole/ctrnn_evolve.py +++ b/examples/pole_balancing/single_pole/ctrnn_evolve.py @@ -27,15 +27,15 @@ def fitness_function(genomes): pop.epoch(fitness_function, 2000, report=1, save_best=0) # Save the winner. +print('Number of evaluations: {0:d}'.format(pop.total_evaluations)) winner = pop.most_fit_genomes[-1] -print('Number of evaluations: {0:d}'.format(winner.ID)) with open('ctrnn_winner_genome', 'wb') as f: pickle.dump(winner, f) print(winner) # Plot the evolution of the best/average fitness. -visualize.plot_stats(pop.most_fit_genomes, pop.avg_fitness_scores, ylog=True, filename="ctrnn_fitness.svg") +visualize.plot_stats(pop.most_fit_genomes, pop.fitness_scores, ylog=True, filename="ctrnn_fitness.svg") # Visualizes speciation visualize.plot_species(pop.species_log, filename="ctrnn_speciation.svg") # Visualize the best network. diff --git a/examples/pole_balancing/single_pole/nn_config b/examples/pole_balancing/single_pole/nn_config index b49302fe..80b547e3 100644 --- a/examples/pole_balancing/single_pole/nn_config +++ b/examples/pole_balancing/single_pole/nn_config @@ -26,22 +26,16 @@ prob_mutate_weight = 0.9 prob_replace_weight = 0.1 weight_mutation_power = 1.5 prob_toggle_link = 0.01 -elitism = 1 +elitism = 2 [genotype compatibility] compatibility_threshold = 3.0 -compatibility_change = 0.0 excess_coefficient = 1.0 disjoint_coefficient = 1.0 weight_coefficient = 0.4 [species] -species_size = 10 survival_threshold = 0.2 -old_threshold = 80 -youth_threshold = 10 -old_penalty = 1.0 -youth_boost = 1.0 max_stagnation = 20 diff --git a/examples/pole_balancing/single_pole/nn_evolve.py b/examples/pole_balancing/single_pole/nn_evolve.py index 8e9d0553..588be0e7 100644 --- a/examples/pole_balancing/single_pole/nn_evolve.py +++ b/examples/pole_balancing/single_pole/nn_evolve.py @@ -27,13 +27,13 @@ def fitness_function(genomes): pop.epoch(fitness_function, 1000) # Save the winner. +print('Number of evaluations: {0:d}'.format(pop.total_evaluations)) winner = pop.most_fit_genomes[-1] -print('Number of evaluations: {0:d}'.format(winner.ID)) with open('nn_winner_genome', 'wb') as f: pickle.dump(winner, f) # Plot the evolution of the best/average fitness. -visualize.plot_stats(pop.most_fit_genomes, pop.avg_fitness_scores, ylog=True, filename="nn_fitness.svg") +visualize.plot_stats(pop.most_fit_genomes, pop.fitness_scores, ylog=True, filename="nn_fitness.svg") # Visualizes speciation visualize.plot_species(pop.species_log, filename="nn_speciation.svg") # Visualize the best network. diff --git a/examples/xor/xor2.py b/examples/xor/xor2.py index 2e328284..ca4d2673 100644 --- a/examples/xor/xor2.py +++ b/examples/xor/xor2.py @@ -1,8 +1,7 @@ """ 2-input XOR example """ from __future__ import print_function -from neat import nn -from neat import population, visualize +from neat import nn, population, visualize xor_inputs = [[0, 0], [0, 1], [1, 0], [1, 1]] xor_outputs = [0, 1, 1, 0] @@ -23,25 +22,20 @@ def eval_fitness(genomes): g.fitness = 1 - error -def run(): - pop = population.Population('xor2_config') - pop.epoch(eval_fitness, 300) +pop = population.Population('xor2_config') +pop.epoch(eval_fitness, 300) - winner = pop.most_fit_genomes[-1] - print('Number of evaluations: {0:d}'.format(winner.ID)) +print('Number of evaluations: {0}'.format(pop.total_evaluations)) - # Verify network output against training data. - print('\nBest network output:') - net = nn.create_feed_forward_phenotype(winner) - for inputs, expected in zip(xor_inputs, xor_outputs): - output = net.serial_activate(inputs) - print("expected {0:1.5f} got {1:1.5f}".format(expected, output[0])) +# Verify network output against training data. +print('\nBest network output:') +winner = pop.most_fit_genomes[-1] +net = nn.create_feed_forward_phenotype(winner) +for inputs, expected in zip(xor_inputs, xor_outputs): + output = net.serial_activate(inputs) + print("expected {0:1.5f} got {1:1.5f}".format(expected, output[0])) - # Visualize the winner network and plot statistics. - visualize.plot_stats(pop.most_fit_genomes, pop.avg_fitness_scores) - visualize.plot_species(pop.species_log) - visualize.draw_net(winner, view=True) - - -if __name__ == '__main__': - run() \ No newline at end of file +# Visualize the winner network and plot statistics. +visualize.plot_stats(pop.most_fit_genomes, pop.fitness_scores) +visualize.plot_species(pop.species_log) +visualize.draw_net(winner, view=True) diff --git a/examples/xor/xor2_config b/examples/xor/xor2_config index cd1609b5..c70f750b 100644 --- a/examples/xor/xor2_config +++ b/examples/xor/xor2_config @@ -37,16 +37,10 @@ elitism = 1 [genotype compatibility] compatibility_threshold = 3.0 -compatibility_change = 0.0 excess_coefficient = 1.0 disjoint_coefficient = 1.0 weight_coefficient = 0.4 [species] -species_size = 10 survival_threshold = 0.2 -old_threshold = 30 -youth_threshold = 10 -old_penalty = 1.0 -youth_boost = 1.0 max_stagnation = 100 diff --git a/examples/xor/xor2_parallel.py b/examples/xor/xor2_parallel.py index 8799a5d3..ad2ec33b 100644 --- a/examples/xor/xor2_parallel.py +++ b/examples/xor/xor2_parallel.py @@ -65,18 +65,18 @@ def fitness(genomes): print("total evolution time {0:.3f} sec".format((time.time() - t0))) print("time per generation {0:.3f} sec".format(((time.time() - t0) / pop.generation))) - winner = pop.most_fit_genomes[-1] - print('Number of evaluations: {0:d}'.format(winner.ID)) + print('Number of evaluations: {0:d}'.format(pop.total_evaluations)) # Verify network output against training data. print('\nBest network output:') + winner = pop.most_fit_genomes[-1] net = nn.create_feed_forward_phenotype(winner) for i, inputs in enumerate(xor_inputs): output = net.serial_activate(inputs) # serial activation print( "{0:1.5f} \t {1:1.5f}".format(xor_outputs[i], output[0])) # Visualize the winner network and plot statistics. - visualize.plot_stats(pop.most_fit_genomes, pop.avg_fitness_scores) + visualize.plot_stats(pop.most_fit_genomes, pop.fitness_scores) visualize.plot_species(pop.species_log) visualize.draw_net(winner, view=True) diff --git a/examples/xor/xor2_spiking.py b/examples/xor/xor2_spiking.py index c10f34b5..40b6eece 100644 --- a/examples/xor/xor2_spiking.py +++ b/examples/xor/xor2_spiking.py @@ -93,13 +93,13 @@ def run(): pop = population.Population(config) pop.epoch(eval_fitness, 200) - winner = pop.most_fit_genomes[-1] - print('Number of evaluations: {0:d}'.format(winner.ID)) + print('Number of evaluations: {0}'.format(pop.total_evaluations)) # Visualize the winner network and plot statistics. - visualize.plot_stats(pop.most_fit_genomes, pop.avg_fitness_scores) + winner = pop.most_fit_genomes[-1] + visualize.draw_net(winner, view=True, node_names={0:'A', 1:'B', 2:'Out1', 3:'Out2'}) + visualize.plot_stats(pop.most_fit_genomes, pop.fitness_scores) visualize.plot_species(pop.species_log) - visualize.draw_net(winner, view=True) # Verify network output against training data. print('\nBest network output:') diff --git a/neat/config.py b/neat/config.py index fffc204b..d62a2b24 100644 --- a/neat/config.py +++ b/neat/config.py @@ -54,20 +54,14 @@ def __init__(self, filename): self.prob_replace_weight = float(parameters.get('genetic', 'prob_replace_weight')) self.weight_mutation_power = float(parameters.get('genetic', 'weight_mutation_power')) self.prob_toggle_link = float(parameters.get('genetic', 'prob_toggle_link')) - self.elitism = float(parameters.get('genetic', 'elitism')) + self.elitism = int(parameters.get('genetic', 'elitism')) # genotype compatibility self.compatibility_threshold = float(parameters.get('genotype compatibility', 'compatibility_threshold')) - self.compatibility_change = float(parameters.get('genotype compatibility', 'compatibility_change')) self.excess_coefficient = float(parameters.get('genotype compatibility', 'excess_coefficient')) self.disjoint_coefficient = float(parameters.get('genotype compatibility', 'disjoint_coefficient')) self.weight_coefficient = float(parameters.get('genotype compatibility', 'weight_coefficient')) # species - self.species_size = int(parameters.get('species', 'species_size')) self.survival_threshold = float(parameters.get('species', 'survival_threshold')) - self.old_threshold = int(parameters.get('species', 'old_threshold')) - self.youth_threshold = int(parameters.get('species', 'youth_threshold')) - self.old_penalty = float(parameters.get('species', 'old_penalty')) - self.youth_boost = float(parameters.get('species', 'youth_boost')) self.max_stagnation = int(parameters.get('species', 'max_stagnation')) diff --git a/neat/diversity.py b/neat/diversity.py index 905d172d..3ecea828 100644 --- a/neat/diversity.py +++ b/neat/diversity.py @@ -3,11 +3,12 @@ ''' from math import ceil -class AgedFitnessSharing(object): + +class ExplicitFitnessSharing(object): ''' This class encapsulates a fitness sharing scheme. It is responsible for computing the number of individuals to be spawned for each species in the - next generation, based on species fitness, age, and size. + next generation, based on species fitness and size. Fitness inside a species is shared by all its members, so that a species that happens to end up with a large initial number of members is less @@ -35,13 +36,6 @@ def compute_spawn_amount(self, species): for f, s in zip(fitnesses, species): # Make all adjusted fitnesses positive, and apply adjustment for population size. af = (f + fitness_shift) / len(s.members) - - # Apply adjustments for species age. - if s.age < self.config.youth_threshold: - af *= self.config.youth_boost - elif s.age > self.config.old_threshold: - af *= self.config.old_penalty - adjusted_fitnesses.append(af) total_adjusted_fitness += af diff --git a/neat/genes.py b/neat/genes.py index 64e456e5..e107b873 100644 --- a/neat/genes.py +++ b/neat/genes.py @@ -52,9 +52,6 @@ def mutate(self, config): self.__mutate_response(config) - - - class ConnectionGene(object): indexer = Indexer(0) __innovations = {} diff --git a/neat/genome.py b/neat/genome.py index 7fb087b0..fbb7d3f4 100644 --- a/neat/genome.py +++ b/neat/genome.py @@ -130,30 +130,24 @@ def _mutate_add_node(self): return ng, conn_to_split # the return is only used in genome_feedforward def _mutate_add_connection(self): - # Only for recurrent networks - total_possible_conns = (len(self.node_genes) - self.num_inputs) * len(self.node_genes) - remaining_conns = total_possible_conns - len(self.conn_genes) - # Check if new connection can be added: - if remaining_conns > 0: - n = randint(0, remaining_conns - 1) - count = 0 - # Count connections - for in_node in self.node_genes.values(): - for out_node in self.node_genes.values(): - # TODO: We do this filtering of input/output/hidden nodes a lot; they should probably - # be separate collections. - if out_node.type == 'INPUT': - continue - - if (in_node.ID, out_node.ID) not in self.conn_genes.keys(): - # Free connection - if count == n: # Connection to create - weight = gauss(0, self.config.weight_stdev) - cg = self._conn_gene_type(in_node.ID, out_node.ID, weight, True) - self.conn_genes[cg.key] = cg - return - else: - count += 1 + ''' + Attempt to add a new connection, the only restriction being that the output + node cannot be one of the network input nodes. + ''' + in_node = choice(list(self.node_genes.values())) + + # TODO: We do this filtering of input/output/hidden nodes a lot; + # they should probably be separate collections. + possible_outputs = [n for n in self.node_genes.values() if n.type != 'INPUT'] + out_node = choice(possible_outputs) + + # Only create the connection if it doesn't already exist. + key = (in_node.ID, out_node.ID) + if key not in self.conn_genes: + weight = gauss(0, self.config.weight_stdev) + enabled = choice([False, True]) + cg = self._conn_gene_type(in_node.ID, out_node.ID, weight, enabled) + self.conn_genes[cg.key] = cg def _mutate_delete_node(self): # Do nothing if there are no hidden nodes. @@ -205,19 +199,25 @@ def distance(self, other): genome1 = other genome2 = self + # If the longest genome is empty, there is nothing to do. + if not genome1.conn_genes: + return 0.0 + N = len(genome1.conn_genes) weight_diff = 0 matching = 0 disjoint = 0 excess = 0 - max_cg_genome2 = max(genome2.conn_genes.values()) + max_cg_genome2 = None + if genome2.conn_genes: + max_cg_genome2 = max(genome2.conn_genes.values()) for cg1 in genome1.conn_genes.values(): try: cg2 = genome2.conn_genes[cg1.key] except KeyError: - if cg1 > max_cg_genome2: + if max_cg_genome2 is not None and cg1 > max_cg_genome2: excess += 1 else: disjoint += 1 @@ -235,20 +235,13 @@ def distance(self, other): return distance def size(self): - """ Defines genome 'complexity': number of hidden nodes plus - number of enabled connections (bias is not considered) - """ - # number of hidden nodes - num_hidden = len(self.node_genes) - self.num_inputs - self.num_outputs - # number of enabled connections - conns_enabled = sum([1 for cg in self.conn_genes.values() if cg.enabled is True]) - - return num_hidden, conns_enabled + '''Returns genome 'complexity', taken to be (number of hidden nodes, number of enabled connections)''' + num_hidden_nodes = len(self.node_genes) - self.num_inputs - self.num_outputs + num_enabled_connections = sum([1 for cg in self.conn_genes.values() if cg.enabled is True]) + return num_hidden_nodes, num_enabled_connections def __lt__(self, other): - """ - Compare genomes by their fitness. - """ + '''Order genomes by fitness.''' return self.fitness < other.fitness def __str__(self): @@ -271,26 +264,26 @@ def add_hidden_nodes(self, num_hidden): assert node_gene.ID not in self.node_genes self.node_genes[node_gene.ID] = node_gene node_id += 1 - # Connect all nodes to it - for pre in self.node_genes.values(): - weight = gauss(0, self.config.weight_stdev) - cg = self._conn_gene_type(pre.ID, node_gene.ID, weight, True) - self.conn_genes[cg.key] = cg - # Connect it to all nodes except input nodes - for post in self.node_genes.values(): - if post.type == 'INPUT': - continue - weight = gauss(0, self.config.weight_stdev) - cg = self._conn_gene_type(node_gene.ID, post.ID, weight, True) - self.conn_genes[cg.key] = cg + # TODO: Add connections based on configuration. + # Connect all nodes to it + # for pre in self.node_genes.values(): + # weight = gauss(0, self.config.weight_stdev) + # cg = self._conn_gene_type(pre.ID, node_gene.ID, weight, True) + # self.conn_genes[cg.key] = cg + # # Connect it to all nodes except input nodes + # for post in self.node_genes.values(): + # if post.type == 'INPUT': + # continue + # + # weight = gauss(0, self.config.weight_stdev) + # cg = self._conn_gene_type(node_gene.ID, post.ID, weight, True) + # self.conn_genes[cg.key] = cg @classmethod def create_unconnected(cls, config, node_gene_type, conn_gene_type): - """ - Factory method - Creates a genome for an unconnected feed-forward network with no hidden nodes. - """ + '''Create a genome for a network with no hidden nodes and no connections.''' + c = cls(config, 0, 0, node_gene_type, conn_gene_type) node_id = 0 # Create node genes @@ -298,7 +291,7 @@ def create_unconnected(cls, config, node_gene_type, conn_gene_type): assert node_id not in c.node_genes c.node_genes[node_id] = c._node_gene_type(node_id, 'INPUT') node_id += 1 - # c.num_inputs += num_input + for i in range(config.output_nodes): node_gene = c._node_gene_type(node_id, node_type='OUTPUT', @@ -306,15 +299,15 @@ def create_unconnected(cls, config, node_gene_type, conn_gene_type): assert node_gene.ID not in c.node_genes c.node_genes[node_gene.ID] = node_gene node_id += 1 + assert node_id == len(c.node_genes) return c @classmethod def create_minimally_connected(cls, config, node_gene_type, conn_gene_type): """ - Factory method - Creates a genome for a minimally connected feed-forward network with no hidden nodes. That is, - each output node will have a single connection from a randomly chosen input node. + Create a genome for a minimally connected feed-forward network with no hidden nodes. + Each output node will have a single connection from a randomly chosen input node. """ c = cls.create_unconnected(config, node_gene_type, conn_gene_type) for node_gene in c.node_genes.values(): @@ -338,8 +331,7 @@ def create_minimally_connected(cls, config, node_gene_type, conn_gene_type): @classmethod def create_fully_connected(cls, config, node_gene_type, conn_gene_type): """ - Factory method - Creates a genome for a fully connected feed-forward network with no hidden nodes. + Create a genome for a fully connected feed-forward network with no hidden nodes. """ c = cls.create_unconnected(config, node_gene_type, conn_gene_type) for node_gene in c.node_genes.values(): @@ -391,36 +383,24 @@ def _mutate_add_node(self): return ng, split_conn def _mutate_add_connection(self): - # Only for feed-forward networks - nhidden = len(self.node_order) - nout = len(self.node_genes) - self.num_inputs - nhidden - - total_possible_conns = (nhidden + nout) * (self.num_inputs + nhidden) - sum(range(nhidden + 1)) - - remaining_conns = total_possible_conns - len(self.conn_genes) - # Check if new connection can be added: - if remaining_conns > 0: - n = randint(0, remaining_conns - 1) - count = 0 - # Count connections - for in_node in self.node_genes.values(): - if in_node.type == 'OUTPUT': - continue - - for out_node in self.node_genes.values(): - if out_node.type == 'INPUT': - continue - - if (in_node.ID, out_node.ID) not in self.conn_genes.keys() and \ - self.__is_connection_feedforward(in_node, out_node): - # Free connection - if count == n: # Connection to create - weight = gauss(0, self.config.weight_stdev) - cg = self._conn_gene_type(in_node.ID, out_node.ID, weight, True) - self.conn_genes[cg.key] = cg - return - else: - count += 1 + ''' + Attempt to add a new connection, with the restrictions that (1) the output node + cannot be one of the network input nodes, and (2) the connection must be feed-forward. + ''' + possible_inputs = [n for n in self.node_genes.values() if n.type != 'OUTPUT'] + possible_outputs = [n for n in self.node_genes.values() if n.type != 'INPUT'] + + in_node = choice(possible_inputs) + out_node = choice(possible_outputs) + + # Only create the connection if it's feed-forward and it doesn't already exist. + if self.__is_connection_feedforward(in_node, out_node): + key = (in_node.ID, out_node.ID) + if key not in self.conn_genes: + weight = gauss(0, self.config.weight_stdev) + enabled = choice([False, True]) + cg = self._conn_gene_type(in_node.ID, out_node.ID, weight, enabled) + self.conn_genes[cg.key] = cg def _mutate_delete_node(self): deleted_id = super(FFGenome, self)._mutate_delete_node() @@ -449,25 +429,25 @@ def add_hidden_nodes(self, num_hidden): self.node_order.append(node_gene.ID) node_id += 1 # Connect all input nodes to it - for pre in self.node_genes.values(): - if pre.type == 'INPUT': - weight = gauss(0, self.config.weight_stdev) - cg = self._conn_gene_type(pre.ID, node_gene.ID, weight, True) - self.conn_genes[cg.key] = cg - assert self.__is_connection_feedforward(pre, node_gene) - # Connect all previous hidden nodes to it - for pre_id in self.node_order[:-1]: - assert pre_id != node_gene.ID - weight = gauss(0, self.config.weight_stdev) - cg = self._conn_gene_type(pre_id, node_gene.ID, weight, True) - self.conn_genes[cg.key] = cg - # Connect it to all output nodes - for post in self.node_genes.values(): - if post.type == 'OUTPUT': - weight = gauss(0, self.config.weight_stdev) - cg = self._conn_gene_type(node_gene.ID, post.ID, weight, True) - self.conn_genes[cg.key] = cg - assert self.__is_connection_feedforward(node_gene, post) + # for pre in self.node_genes.values(): + # if pre.type == 'INPUT': + # weight = gauss(0, self.config.weight_stdev) + # cg = self._conn_gene_type(pre.ID, node_gene.ID, weight, True) + # self.conn_genes[cg.key] = cg + # assert self.__is_connection_feedforward(pre, node_gene) + # # Connect all previous hidden nodes to it + # for pre_id in self.node_order[:-1]: + # assert pre_id != node_gene.ID + # weight = gauss(0, self.config.weight_stdev) + # cg = self._conn_gene_type(pre_id, node_gene.ID, weight, True) + # self.conn_genes[cg.key] = cg + # # Connect it to all output nodes + # for post in self.node_genes.values(): + # if post.type == 'OUTPUT': + # weight = gauss(0, self.config.weight_stdev) + # cg = self._conn_gene_type(node_gene.ID, post.ID, weight, True) + # self.conn_genes[cg.key] = cg + # assert self.__is_connection_feedforward(node_gene, post) def __str__(self): s = super(FFGenome, self).__str__() diff --git a/neat/population.py b/neat/population.py index 1e447b2e..1dfd88f2 100644 --- a/neat/population.py +++ b/neat/population.py @@ -11,7 +11,10 @@ from neat.genes import NodeGene, ConnectionGene from neat.species import Species from neat.math_util import mean, stdev -from neat.diversity import AgedFitnessSharing +from neat.diversity import ExplicitFitnessSharing + +class MassExtinctionException(Exception): + pass class Population(object): @@ -19,70 +22,70 @@ class Population(object): def __init__(self, config, checkpoint_file=None, initial_population=None, node_gene_type=NodeGene, conn_gene_type=ConnectionGene, - diversity_type=AgedFitnessSharing): + diversity_type=ExplicitFitnessSharing): # If config is not a Config object, assume it is a path to the config file. if not isinstance(config, Config): config = Config(config) self.config = config - - self.population = None + # TODO: Move node_gene_type, conn_gene_type, and diversity_type to the configuration object. self.node_gene_type = node_gene_type self.conn_gene_type = conn_gene_type self.diversity = diversity_type(self.config) + self.population = None + self.species = [] + self.species_log = [] + self.fitness_scores = [] + self.most_fit_genomes = [] + self.generation = -1 + self.total_evaluations = 0 + if checkpoint_file: - # Start from a saved checkpoint. - self.__resume_checkpoint(checkpoint_file) + assert initial_population is None + self._load_checkpoint(checkpoint_file) + elif initial_population is None: + self._create_population() else: - # currently living species - self.__species = [] - # species history - self.species_log = [] + self.population = initial_population - # List of statistics for all generations. - self.avg_fitness_scores = [] - self.most_fit_genomes = [] + # Partition the population into species based on current configuration. + self._speciate() - if initial_population is None: - self.__create_population() - else: - self.population = initial_population - self.generation = -1 - - def __resume_checkpoint(self, checkpoint): - ''' - Resumes the simulation from a previous saved point. This is done by swapping out our existing - __dict__ with the loaded population's. - ''' - # TODO: Wouldn't it just be better to create a class method to load and return the stored Population - # object as-is? I don't know if there are hidden side effects to directly replacing __dict__. + def _load_checkpoint(self, checkpoint): + '''Resumes the simulation from a previous saved point.''' with gzip.open(checkpoint) as f: - print('Resuming from a previous point: {0!s}'.format(checkpoint)) - # when unpickling __init__ is not called again - previous_pop = pickle.load(f) - self.__dict__ = previous_pop.__dict__ - - print('Loading random state') - random.setstate(pickle.load(f)) - - def __create_checkpoint(self, report): - """ Saves the current simulation state. """ - if report: - print('Creating checkpoint file at generation: {0:d}'.format(self.generation)) - - with gzip.open('checkpoint_' + str(self.generation), 'w', compresslevel=5) as f: - # Write the entire population state. - pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL) - # Remember the current random number state. - pickle.dump(random.getstate(), f, protocol=2) - - def __create_population(self): + print('Resuming from a previous point: {0}'.format(checkpoint)) + + (self.population, + self.species, + self.species_log, + self.fitness_scores, + self.most_fit_genomes, + self.generation, + random_state) = pickle.load(f) + + random.setstate(random_state) + + def _create_checkpoint(self): + """ Save the current simulation state. """ + fn = 'neat-checkpoint-{0}'.format(self.generation) + with gzip.open(fn, 'w', compresslevel=5) as f: + data = (self.population, + self.species, + self.species_log, + self.fitness_scores, + self.most_fit_genomes, + self.generation, + random.getstate()) + pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL) + + def _create_population(self): if self.config.feedforward: - genotypes = FFGenome + genotype = FFGenome else: - genotypes = Genome + genotype = Genome self.population = [] # TODO: Add FS-NEAT support, which creates an empty connection set, and then performs a @@ -93,30 +96,24 @@ def __create_population(self): # 3. FS-NEAT connected (one random connection) if self.config.fully_connected: for i in range(self.config.pop_size): - g = genotypes.create_fully_connected(self.config, self.node_gene_type, self.conn_gene_type) + g = genotype.create_fully_connected(self.config, self.node_gene_type, self.conn_gene_type) self.population.append(g) else: for i in range(self.config.pop_size): - g = genotypes.create_minimally_connected(self.config, self.node_gene_type, self.conn_gene_type) + g = genotype.create_minimally_connected(self.config, self.node_gene_type, self.conn_gene_type) self.population.append(g) if self.config.hidden_nodes > 0: for g in self.population: g.add_hidden_nodes(self.config.hidden_nodes) - def __repr__(self): - s = "Population size: {0:d}".format(self.config.pop_size) - s += "\nTotal species: {0:d}".format(len(self.__species)) - return s - - def __speciate(self, report): - """ Group genomes into species by similarity """ - # Speciate the population + def _speciate(self): + """Group genomes into species by genetic similarity.""" for individual in self.population: # Find the species with the most similar representative. min_distance = None closest_species = None - for s in self.__species: + for s in self.species: distance = individual.distance(s.representative) if distance < self.config.compatibility_threshold: if min_distance is None or distance < min_distance: @@ -127,51 +124,19 @@ def __speciate(self, report): closest_species.add(individual) else: # No species is similar enough, create a new species for this individual. - self.__species.append(Species(individual)) + self.species.append(Species(individual)) - # python technical note: - # we need a "working copy" list when removing elements while looping - # otherwise we might end up having sync issues - for s in self.__species[:]: - # this happens when no genomes are compatible with the species + # Verify that no species are empty. + for s in self.species: if not s.members: - #raise Exception('TODO: fix this') - if report: - print("Removing species {0:d} for being empty".format(s.ID)) - # remove empty species - self.__species.remove(s) - - self.__set_compatibility_threshold(report) - - def __set_compatibility_threshold(self, report): - """ Controls compatibility threshold """ - t = self.config.compatibility_threshold - dt = self.config.compatibility_change - if len(self.__species) > self.config.species_size: - t += dt - elif len(self.__species) < self.config.species_size: - t = max(0.0, t - dt) - - if self.config.compatibility_threshold != t: - if report: - print("Adjusted compatibility threshold to {0:f}".format(t)) - self.config.compatibility_threshold = t - - def __log_species(self): - """ Logging species data for visualizing speciation """ - temp = [] - if self.__species: - higher = max([s.ID for s in self.__species]) - for i in range(1, higher + 1): - found_species = False - for s in self.__species: - if i == s.ID: - temp.append(len(s.members)) - found_species = True - break - if not found_species: - temp.append(0) - self.species_log.append(temp) + raise Exception('TODO: fix empty species bug') + + def _log_stats(self): + """ Gather data for visualization/reporting purposes. """ + species_sizes = dict((s.ID, len(s.members)) for s in self.species) + self.species_log.append(species_sizes) + self.most_fit_genomes.append(copy.deepcopy(max(self.population))) + self.fitness_scores.append([c.fitness for c in self.population]) def epoch(self, fitness_function, n, report=True, save_best=False, checkpoint_interval=10, checkpoint_generation=None): @@ -190,122 +155,82 @@ def epoch(self, fitness_function, n, report=True, save_best=False, checkpoint_in self.generation += 1 if report: - print('\n ****** Running generation {0:d} ****** \n'.format(self.generation)) + print('\n ****** Running generation {0} ****** \n'.format(self.generation)) # Evaluate individuals fitness_function(self.population) - # Speciates the population - self.__speciate(report) + self.total_evaluations += len(self.population) - # Current generation's best genome - self.most_fit_genomes.append(copy.deepcopy(max(self.population))) - # Current population's average fitness - self.avg_fitness_scores.append(mean([c.fitness for c in self.population])) + # Gather statistics. + self._log_stats() # Print some statistics best = self.most_fit_genomes[-1] - - # saves the best genome from the current generation + if report: + fit_mean = mean([c.fitness for c in self.population]) + fit_std = stdev([c.fitness for c in self.population]) + print('Population\'s average fitness: {0:3.5f} stdev: {1:3.5f}'.format(fit_mean, fit_std)) + print('Best fitness: {0:3.5f} - size: {1!r} - species {2} - id {3}'.format(best.fitness, best.size(), + best.species_id, best.ID)) + print('Species length: {0:d} totaling {1:d} individuals'.format(len(self.species), sum([len(s.members) for s in self.species]))) + print('Species ID : {0!s}'.format([s.ID for s in self.species])) + print('Each species size: {0!s}'.format([len(s.members) for s in self.species])) + print('Amount to spawn : {0!s}'.format([s.spawn_amount for s in self.species])) + print('Species age : {0}'.format([s.age for s in self.species])) + print('Species avg fit : {0!r}'.format([s.get_average_fitness() for s in self.species])) + print('Species no improv: {0!r}'.format([s.no_improvement_age for s in self.species])) + + # Saves the best genome from the current generation if requested. if save_best: f = open('best_genome_' + str(self.generation), 'w') pickle.dump(best, f) f.close() - # Stops the simulation - if best.fitness > self.config.max_fitness_threshold: + # End when the fitness threshold is reached. + if best.fitness >= self.config.max_fitness_threshold: if report: - print('\nBest individual in epoch {0!s} meets fitness threshold - complexity: {1!s}'.format( + print('\nBest individual in epoch {0} meets fitness threshold - complexity: {1!r}'.format( self.generation, best.size())) break - # Remove stagnated species and its members (except if it has the best genome) - for s in self.__species[:]: + # Remove stagnated species. + #TODO: Log species removal for visualization purposes. + new_species = [] + for s in self.species: s.update_stagnation() - if s.no_improvement_age > self.config.max_stagnation: + if s.no_improvement_age <= self.config.max_stagnation: + new_species.append(s) + else: if report: - print("\n Species {0:2d} (with {1:2d} individuals) is stagnated: removing it".format(s.ID, len(s.members))) - # removing species - self.__species.remove(s) - # removing all the species' members - # TODO: can be optimized! - for c in self.population[:]: - if c.species_id == s.ID: - self.population.remove(c) - - # Compute spawn levels for each remaining species - self.diversity.compute_spawn_amount(self.__species) - - # Verify that all species received non-zero spawn counts, as the speciation mechanism - # is intended to allow initially less-fit species time to improve before making them - # extinct via the stagnation mechanism. - for s in self.__species: - assert s.spawn_amount > 0 + print("\n Species {0} with {1} members is stagnated: removing it".format(s.ID, len(s.members))) + self.species = new_species - # Logging speciation stats - self.__log_species() - - if report: - if self.population: - std_dev = stdev([c.fitness for c in self.population]) - print('Population\'s average fitness: {0:3.5f} stdev: {1:3.5f}'.format(self.avg_fitness_scores[-1], std_dev)) - print('Best fitness: {0:2.12} - size: {1!r} - species {2} - id {3}'.format(best.fitness, best.size(), best.species_id, best.ID)) - print('Species length: {0:d} totaling {1:d} individuals'.format(len(self.__species), sum([len(s.members) for s in self.__species]))) - print('Species ID : {0!s}'.format([s.ID for s in self.__species])) - print('Each species size: {0!s}'.format([len(s.members) for s in self.__species])) - print('Amount to spawn : {0!s}'.format([s.spawn_amount for s in self.__species])) - print('Species age : {0}'.format([s.age for s in self.__species])) - print('Species avg fit : {0!s}'.format([s.get_average_fitness() for s in self.__species])) - print('Species no improv: {0!s}'.format([s.no_improvement_age for s in self.__species])) - else: + # Check for complete extinction. + if not self.species: + if report: print('All species extinct.') + raise MassExtinctionException() + + # Compute spawn levels for all current species and then reproduce. + self.diversity.compute_spawn_amount(self.species) + self.population = [] + for s in self.species: + # Verify that all species received non-zero spawn counts, as the speciation mechanism + # is intended to allow initially less fit species time to improve before making them + # extinct via the stagnation mechanism. + assert s.spawn_amount > 0 + self.population.extend(s.reproduce(self.config)) - # -------------------------- Producing new offspring -------------------------- # - new_population = [] # next generation's population - - # If no species are left, create a new population from scratch, otherwise top off - # population by reproducing existing species. - if self.__species: - for s in self.__species: - new_population.extend(s.reproduce(self.config)) - - # Controls under or overflow # - fill = self.config.pop_size - len(new_population) - if fill < 0: # overflow - if report: - print(' Removing {0:d} excess individual(s) from the new population'.format(-fill)) - # TODO: This is dangerous! I can't remove a species' representative! - new_population = new_population[:fill] # Removing the last added members - - if fill > 0: # underflow - if report: - print(' Producing {0:d} more individual(s) to fill up the new population'.format(fill)) - - while fill > 0: - # Selects a random genome from population - parent1 = random.choice(self.population) - # Search for a mate within the same species - found = False - for c in self.population: - # what if c is parent1 itself? - if c.species_id == parent1.species_id: - child = parent1.crossover(c) - new_population.append(child.mutate()) - found = True - break - if not found: - # If no mate was found, just mutate it - new_population.append(parent1.mutate()) - # new_population.append(genome.FFGenome.create_fully_connected()) - fill -= 1 - - assert self.config.pop_size == len(new_population), 'Different population sizes!' - # Updates current population - self.population = new_population - else: - self.__create_population() + self._speciate() if checkpoint_interval is not None and time.time() > t0 + 60 * checkpoint_interval: - self.__create_checkpoint(report) - t0 = time.time() # updates the counter + if report: + print('Creating timed checkpoint file at generation: {0}'.format(self.generation)) + self._create_checkpoint() + + # Update the checkpoint time. + t0 = time.time() elif checkpoint_generation is not None and self.generation % checkpoint_generation == 0: - self.__create_checkpoint(report) + if report: + print('Creating generation checkpoint file at generation: {0}'.format(self.generation)) + self._create_checkpoint() diff --git a/neat/species.py b/neat/species.py index 14529118..e71863e1 100644 --- a/neat/species.py +++ b/neat/species.py @@ -22,23 +22,14 @@ def __init__(self, first_individual, previous_id=None): def add(self, individual): individual.species_id = self.ID self.members.append(individual) - # choose a new random representative for the species - self.representative = random.choice(self.members) - - def __str__(self): - s = "\n Species {0:2d} size: {1:3d} age: {2:3d} spawn: {3:3d} ".format(self.ID, len(self.members), self.age, self.spawn_amount) - s += "\n No improvement: {0:3d} \t avg. fitness: {1:1.8f}".format(self.no_improvement_age, self.last_avg_fitness) - return s def get_average_fitness(self): - """ Returns the raw average fitness over all members in the species.""" + """ Returns the average fitness over all members in the species.""" return mean([c.fitness for c in self.members]) def update_stagnation(self): """ Updates no_improvement_age based on average fitness progress.""" fitness = self.get_average_fitness() - - # Check for increase in mean fitness and adjust "no improvement" count as necessary. if fitness > self.last_avg_fitness: self.last_avg_fitness = fitness self.no_improvement_age = 0 @@ -46,19 +37,18 @@ def update_stagnation(self): self.no_improvement_age += 1 def reproduce(self, config): - """ Returns a list of 'self.spawn_amount' new individuals """ + """ + Update species age, clear the current membership list, and return a list of 'self.spawn_amount' new individuals. + """ + self.age += 1 - offspring = [] # new offspring for this species - self.age += 1 # increment species age + # Sort with most fit members first. + self.members.sort(reverse=True) - self.members.sort() # sort species's members by their fitness - self.members.reverse() # best members first - - if config.elitism: - # TODO: Wouldn't it be better if we set elitism=2,3,4... - # depending on the size of each species? - offspring.append(self.members[0]) - self.spawn_amount -= 1 + offspring = [] + if config.elitism > 0: + offspring.extend(self.members[:config.elitism]) + self.spawn_amount -= config.elitism # Keep a fraction of the current population for reproduction. survivors = int(round(len(self.members) * config.survival_threshold)) @@ -69,19 +59,19 @@ def reproduce(self, config): while self.spawn_amount > 0: self.spawn_amount -= 1 - # Select two parents at random from the remaining members. + # Select two parents at random from the given set of members. parent1 = random.choice(self.members) parent2 = random.choice(self.members) - # Note that if the parents are not distinct, crossover should - # be idempotent. TODO: Write a test for that. + # Note that if the parents are not distinct, crossover should produce a + # genetically identical clone of the parent (but with a different ID). child = parent1.crossover(parent2) offspring.append(child.mutate()) - # reset species (new members will be added again when speciating) + # Reset species members--the speciation process in Population will repopulate this list. self.members = [] - # select a new random representative member + # Select a new random representative member from the new offspring. self.representative = random.choice(offspring) return offspring diff --git a/neat/visualize.py b/neat/visualize.py index 4b3075f2..63f7a069 100644 --- a/neat/visualize.py +++ b/neat/visualize.py @@ -20,7 +20,10 @@ warnings.warn('Could not import optional dependency NumPy.') -def plot_stats(best_genomes, avg_scores, ylog=False, view=False, filename='avg_fitness.svg'): +from neat.math_util import mean + + +def plot_stats(best_genomes, fitness_scores, ylog=False, view=False, filename='avg_fitness.svg'): """ Plots the population's average and best fitness. """ if plt is None: warnings.warn("This display is not available due to a missing optional dependency (matplotlib)") @@ -30,6 +33,8 @@ def plot_stats(best_genomes, avg_scores, ylog=False, view=False, filename='avg_f fitness = [c.fitness for c in best_genomes] + avg_scores = [mean(f) for f in fitness_scores] + plt.plot(generation, avg_scores, 'b-', label="average") plt.plot(generation, fitness, 'r-', label="best") @@ -97,11 +102,15 @@ def plot_species(species_log, view=False, filename='speciation.svg'): return num_generations = len(species_log) - num_species = max(map(len, species_log)) + all_species = set() + for gen_data in species_log: + for sid, scount in gen_data.items(): + all_species.add(sid) + + max_species = max(all_species) curves = [] - for gen in species_log: - species = [0] * num_species - species[:len(gen)] = gen + for gen_data in species_log: + species = [gen_data.get(sid, 0) for sid in range(max_species + 1)] curves.append(np.array(species)) curves = np.array(curves).T @@ -120,13 +129,18 @@ def plot_species(species_log, view=False, filename='speciation.svg'): plt.close() -def draw_net(genome, view=False, filename=None): +def draw_net(genome, view=False, filename=None, node_names=None): """ Receives a genome and draws a neural network with arbitrary topology. """ # Attributes for network nodes. if graphviz is None: warnings.warn("This display is not available due to a missing optional dependency (graphviz)") return + if node_names is None: + node_names = {} + + assert type(node_names) is dict + node_attrs = { 'shape': 'circle', 'fontsize': '9', @@ -147,15 +161,17 @@ def draw_net(genome, view=False, filename=None): for ng_id, ng in genome.node_genes.items(): if ng.type == 'INPUT': - dot.node(str(ng_id), _attributes=input_attrs) + name = node_names.get(ng_id, str(ng_id)) + dot.node(name, _attributes=input_attrs) for ng_id, ng in genome.node_genes.items(): if ng.type == 'OUTPUT': - dot.node(str(ng_id), _attributes=output_attrs) + name = node_names.get(ng_id, str(ng_id)) + dot.node(name, _attributes=output_attrs) for cg in genome.conn_genes.values(): - a = str(cg.in_node_id) - b = str(cg.out_node_id) + a = node_names.get(cg.in_node_id, str(cg.in_node_id)) + b = node_names.get(cg.out_node_id, str(cg.out_node_id)) style = 'solid' if cg.enabled else 'dotted' color = 'green' if cg.weight > 0 else 'red' width = str(0.1 + abs(cg.weight / 5.0)) diff --git a/tests/test_configuration b/tests/test_configuration index d2bd5517..94b51162 100644 --- a/tests/test_configuration +++ b/tests/test_configuration @@ -28,16 +28,10 @@ elitism = 1 [genotype compatibility] compatibility_threshold = 3.0 -compatibility_change = 0.0 excess_coefficient = 1.0 disjoint_coefficient = 1.0 weight_coefficient = 0.4 [species] -species_size = 10 survival_threshold = 0.2 -old_threshold = 30 -youth_threshold = 10 -old_penalty = 0.2 -youth_boost = 1.2 max_stagnation = 15 diff --git a/tests/test_genome.py b/tests/test_genome.py index 56b36d40..b58521b0 100644 --- a/tests/test_genome.py +++ b/tests/test_genome.py @@ -4,27 +4,97 @@ from neat.config import Config -def test_recurrent(): +def check_simple(type): local_dir = os.path.dirname(__file__) config = Config(os.path.join(local_dir, 'test_configuration')) - c1 = genome.Genome.create_fully_connected(config, genes.NodeGene, genes.ConnectionGene) + c1 = type.create_fully_connected(config, genes.NodeGene, genes.ConnectionGene) # add two hidden nodes - #c1.add_hidden_nodes(2) + c1.add_hidden_nodes(2) # apply some mutations c1._mutate_add_node() c1._mutate_add_connection() +def test_recurrent(): + check_simple(genome.Genome) + + def test_feed_forward(): + check_simple(genome.FFGenome) + + +def check_self_crossover(type): + # Check that self-crossover produces a genetically identical child (with a different ID). local_dir = os.path.dirname(__file__) config = Config(os.path.join(local_dir, 'test_configuration')) - c2 = genome.FFGenome.create_fully_connected(config, genes.NodeGene, genes.ConnectionGene) + c = type.create_fully_connected(config, genes.NodeGene, genes.ConnectionGene) + c.fitness = 0.0 - # add two hidden nodes - #c2.add_hidden_nodes(2) + cnew = c.crossover(c) + assert cnew.ID != c.ID + assert len(cnew.conn_genes) == len(c.conn_genes) + for kold, vold in cnew.conn_genes.items(): + assert kold in c.conn_genes + vnew = c.conn_genes[kold] + assert vold.is_same_innov(vnew) - # apply some mutations - c2._mutate_add_node() - c2._mutate_add_connection() + assert vnew.weight == vold.weight + assert vnew.in_node_id == vold.in_node_id + assert vnew.out_node_id == vold.out_node_id + assert vnew.enabled == vold.enabled + + assert len(cnew.node_genes) == len(c.node_genes) + for kold, vold in cnew.node_genes.items(): + assert kold in c.node_genes + vnew = c.node_genes[kold] + + assert vnew.ID == vold.ID + assert vnew.type == vold.type + assert vnew.bias == vold.bias + assert vnew.response == vold.response + assert vnew.activation_type == vold.activation_type + + +def test_recurrent_self_crossover(): + check_self_crossover(genome.Genome) + + +def test_feed_forward_self_crossover(): + check_self_crossover(genome.FFGenome) + + +def check_add_connection(type, feed_forward): + local_dir = os.path.dirname(__file__) + config = Config(os.path.join(local_dir, 'test_configuration')) + config.input_nodes = 3 + config.output_nodes = 4 + config.hidden_nodes = 5 + config.feedforward = feed_forward + N = config.input_nodes + config.hidden_nodes + config.output_nodes + + connections = {} + for a in range(100): + g = type.create_unconnected(config, genes.NodeGene, genes.ConnectionGene) + g.add_hidden_nodes(config.hidden_nodes) + for b in range(1000): + g._mutate_add_connection() + for c in g.conn_genes.values(): + connections[c.key] = connections.get(c.key, 0) + 1 + + # TODO: The connections should be returned to the caller and checked + # against the constraints/assumptions particular to the network type. + for i in range(N): + values = [] + for j in range(N): + values.append(connections.get((i, j), 0)) + print("{0:2d}: {1}".format(i, " ".join("{0:3d}".format(x) for x in values))) + + +def test_recurrent_add_connection(): + check_add_connection(genome.Genome, 0) + + +def test_feed_forward_add_connection(): + check_add_connection(genome.FFGenome, 1) \ No newline at end of file