Skip to content

Commit

Permalink
Merged @mstechly's statistics and index reset pull request.
Browse files Browse the repository at this point in the history
Simplified Population statistics collection, and keep per-species fitness data.
Simplified visualize and statistics functions to take the Population object directly, and factored out common statistical operations on Population statistics into statistics.py.
Updated basic XOR example to show usage of statistics functions.
  • Loading branch information
CodeReclaimers committed Dec 23, 2015
1 parent c7f47b3 commit fd5f2d7
Show file tree
Hide file tree
Showing 9 changed files with 108 additions and 83 deletions.
4 changes: 2 additions & 2 deletions examples/pole_balancing/single_pole/ctrnn_evolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def fitness_function(genomes):
print(winner)

# Plot the evolution of the best/average fitness.
visualize.plot_stats(pop.most_fit_genomes, pop.fitness_scores, ylog=True, filename="ctrnn_fitness.svg")
visualize.plot_stats(pop, ylog=True, filename="ctrnn_fitness.svg")
# Visualizes speciation
visualize.plot_species(pop.species_log, filename="ctrnn_speciation.svg")
visualize.plot_species(pop, filename="ctrnn_speciation.svg")
# Visualize the best network.
visualize.draw_net(winner, view=True, filename="ctrnn_winner.gv")
4 changes: 2 additions & 2 deletions examples/pole_balancing/single_pole/nn_evolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def fitness_function(genomes):
pickle.dump(winner, f)

# Plot the evolution of the best/average fitness.
visualize.plot_stats(pop.most_fit_genomes, pop.fitness_scores, ylog=True, filename="nn_fitness.svg")
visualize.plot_stats(pop, ylog=True, filename="nn_fitness.svg")
# Visualizes speciation
visualize.plot_species(pop.species_log, filename="nn_speciation.svg")
visualize.plot_species(pop, filename="nn_speciation.svg")
# Visualize the best network.
visualize.draw_net(winner, view=True, filename="nn_winner.gv")
11 changes: 7 additions & 4 deletions examples/xor/xor2.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
""" 2-input XOR example """
from __future__ import print_function

from neat import nn, population, visualize
from neat import nn, population, statistics, visualize

xor_inputs = [[0, 0], [0, 1], [1, 0], [1, 1]]
xor_outputs = [0, 1, 1, 0]
Expand Down Expand Up @@ -35,7 +35,10 @@ def eval_fitness(genomes):
output = winner_net.serial_activate(inputs)
print("expected {0:1.5f} got {1:1.5f}".format(expected, output[0]))

# Visualize the winner network and plot statistics.
visualize.plot_stats(pop.most_fit_genomes, pop.fitness_scores)
visualize.plot_species(pop.species_log)
# Visualize the winner network and plot/log statistics.
visualize.plot_stats(pop)
visualize.plot_species(pop)
visualize.draw_net(winner, view=True)
statistics.save_stats(pop)
statistics.save_species_count(pop)
statistics.save_species_fitness(pop)
4 changes: 2 additions & 2 deletions examples/xor/xor2_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def fitness(genomes):
print("{0:1.5f} \t {1:1.5f}".format(xor_outputs[i], output[0]))

# Visualize the winner network and plot statistics.
visualize.plot_stats(pop.most_fit_genomes, pop.fitness_scores)
visualize.plot_species(pop.species_log)
visualize.plot_stats(pop)
visualize.plot_species(pop)
visualize.draw_net(winner, view=True)


Expand Down
4 changes: 2 additions & 2 deletions examples/xor/xor2_spiking.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ def run():
# Visualize the winner network and plot statistics.
winner = pop.most_fit_genomes[-1]
visualize.draw_net(winner, view=True, node_names={0: 'A', 1: 'B', 2: 'Out1', 3: 'Out2'})
visualize.plot_stats(pop.most_fit_genomes, pop.fitness_scores)
visualize.plot_species(pop.species_log)
visualize.plot_stats(pop)
visualize.plot_species(pop)

# Verify network output against training data.
print('\nBest network output:')
Expand Down
18 changes: 13 additions & 5 deletions neat/population.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ def __init__(self, config, checkpoint_file=None, initial_population=None,

self.population = None
self.species = []
self.species_log = []
self.fitness_scores = []
self.generation_statistics = []
self.most_fit_genomes = []
self.generation = -1
self.total_evaluations = 0
Expand All @@ -50,6 +49,10 @@ def __init__(self, config, checkpoint_file=None, initial_population=None,
# Partition the population into species based on current configuration.
self._speciate()

@staticmethod
def clear_indexer(cls=None):
    """Clear the species ID indexer by delegating to Species.clear_indexer().

    ``cls`` is unused. It was originally a required parameter on this
    static method, which made the natural call ``Population.clear_indexer()``
    fail with ``TypeError``. It now defaults to ``None`` so the method can be
    called with no arguments, while callers that passed a placeholder
    argument keep working.
    """
    Species.clear_indexer()

def _load_checkpoint(self, checkpoint):
'''Resumes the simulation from a previous saved point.'''
with gzip.open(checkpoint) as f:
Expand Down Expand Up @@ -130,10 +133,15 @@ def _speciate(self):

def _log_stats(self):
""" Gather data for visualization/reporting purposes. """
species_sizes = dict((s.ID, len(s.members)) for s in self.species)
self.species_log.append(species_sizes)
# Keep a deep copy of the best genome, so that future modifications to the genome
# do not produce an unexpected change in statistics.
self.most_fit_genomes.append(copy.deepcopy(max(self.population)))
self.fitness_scores.append([c.fitness for c in self.population])

# Store the fitnesses of the members of each currently active species.
species_stats = {}
for s in self.species:
species_stats[s.ID] = [c.fitness for c in s.members]
self.generation_statistics.append(species_stats)

def epoch(self, fitness_function, n, report=True, save_best=False, checkpoint_interval=10,
checkpoint_generation=None):
Expand Down
8 changes: 4 additions & 4 deletions neat/species.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ class Species(object):
""" A collection of genetically similar individuals."""
indexer = Indexer(1)

@classmethod
def clear_indexer(cls):
    """Clear the class-level indexer that hands out species IDs."""
    indexer = cls.indexer
    indexer.clear()

def __init__(self, first_individual, previous_id=None):
self.representative = first_individual
self.ID = Species.indexer.next(previous_id)
Expand Down Expand Up @@ -75,7 +79,3 @@ def reproduce(self, config):
self.representative = random.choice(offspring)

return offspring

def clearIndexer(self):
Species.indexer.clear()

102 changes: 63 additions & 39 deletions neat/statistics.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,78 @@
# -*- coding: UTF-8 -*-
from __future__ import print_function
import warnings

import csv

try:
import numpy as np
except ImportError:
np = None
warnings.warn('Could not import optional dependency NumPy.')
from neat.math_util import mean


def get_average_fitness(population):
    """Return the mean fitness of the whole population for each generation.

    ``population.generation_statistics`` holds one dict per generation that
    maps a species ID to the list of its members' fitness values. For every
    generation the fitness lists of all species are pooled and averaged.
    """
    averages = []
    for species_fitness in population.generation_statistics:
        # Flatten the per-species fitness lists into one list of scores.
        pooled = [score
                  for member_scores in species_fitness.values()
                  for score in member_scores]
        averages.append(mean(pooled))

    return averages


def get_species_sizes(population):
    """Return the number of members of each species for every generation.

    ``population.generation_statistics`` holds one dict per generation that
    maps a species ID to the list of its members' fitness values. The result
    has one row per generation; each row has one count for every species ID
    from 1 up to the largest ID seen in any generation, with 0 for species
    that are absent in that generation.

    Returns an empty list when no generation has been recorded, and a list of
    empty rows when generations exist but none of them contains a species.
    """
    generations = population.generation_statistics

    # Collect every species ID that appears in any generation.
    all_species = set()
    for gen_data in generations:
        all_species.update(gen_data)

    # max() raises ValueError on an empty set, so fall back to "no species".
    max_species = max(all_species) if all_species else 0

    return [[len(gen_data.get(sid, [])) for sid in range(1, max_species + 1)]
            for gen_data in generations]

def save_stats(best_genomes, avg_scores, ylog=False, view=False, filename='fitness_history.csv'):
""" Saves the population's average and best fitness. """
csvfile = open(filename, 'wb')
statWriter = csv.writer(csvfile, delimiter=' ')

generation = range(len(best_genomes))
fitness = [c.fitness for c in best_genomes]
def get_species_fitness(population, null_value=''):
all_species = set()
for gen_data in population.generation_statistics:
all_species = all_species.union(gen_data.keys())

for i in generation:
statWriter.writerow([fitness[i], avg_scores[i]])
max_species = max(all_species)
species_fitness = []
for gen_data in population.generation_statistics:
member_fitness = [gen_data.get(sid, []) for sid in range(1, max_species + 1)]
fitness = []
for mf in member_fitness:
if mf:
fitness.append(mean(mf))
else:
fitness.append(null_value)
species_fitness.append(fitness)

csvfile.close()
return species_fitness


def save_species_count(species_log, view=False, filename='speciation.csv'):
""" Visualizes speciation throughout evolution. """
csvfile = open(filename, 'wb')
statWriter = csv.writer(csvfile, delimiter=' ')
def save_stats(population, delimiter=' ', filename='fitness_history.csv'):
    """Save the population's best and average fitness, one row per generation.

    Each row holds the fitness of that generation's entry in
    ``population.most_fit_genomes`` followed by the average fitness computed
    by ``get_average_fitness``. Fields are separated by ``delimiter`` and the
    rows are written to ``filename``.
    """
    import sys

    # Gather the data first so a failure here does not leave an empty file.
    best_fitness = [c.fitness for c in population.most_fit_genomes]
    avg_fitness = get_average_fitness(population)

    # The csv module emits its own line terminators, so the file must not
    # translate newlines: newline='' on Python 3, binary mode on Python 2.
    # Otherwise rows end in "\r\r\n" on Windows.
    if sys.version_info[0] >= 3:
        f = open(filename, 'w', newline='')
    else:
        f = open(filename, 'wb')

    with f:
        w = csv.writer(f, delimiter=delimiter)
        for best, avg in zip(best_fitness, avg_fitness):
            w.writerow([best, avg])

num_generations = len(species_log)
num_species = max(map(len, species_log))
curves = []
for gen in species_log:
species = [0] * num_species
species[:len(gen)] = gen
statWriter.writerow(species)

csvfile.close()
def save_species_count(population, delimiter=' ', filename='speciation.csv'):
    """Log the size of every species throughout evolution.

    Writes one row per generation to ``filename``, using the per-species
    member counts returned by ``get_species_sizes``. Fields are separated by
    ``delimiter``.
    """
    import sys

    # Gather the data first so a failure here does not leave an empty file.
    species_sizes = get_species_sizes(population)

    # The csv module emits its own line terminators, so the file must not
    # translate newlines: newline='' on Python 3, binary mode on Python 2.
    # Otherwise rows end in "\r\r\n" on Windows.
    if sys.version_info[0] >= 3:
        f = open(filename, 'w', newline='')
    else:
        f = open(filename, 'wb')

    with f:
        w = csv.writer(f, delimiter=delimiter)
        for s in species_sizes:
            w.writerow(s)

def save_species_fitness(species_fitness_log, view=False, filename='species_fitness.csv'):
""" Visualizes speciation throughout evolution. """
csvfile = open(filename, 'wb')
statWriter = csv.writer(csvfile, delimiter=' ')
num_generations = len(species_fitness_log)
num_species = max(map(len, species_fitness_log))
curves = []
for gen in species_fitness_log:
species = ["NA"] * num_species
species[:len(gen)] = gen
statWriter.writerow(species)

csvfile.close()
def save_species_fitness(population, delimiter=' ', null_value='NA', filename='species_fitness.csv'):
    """Log each species' average fitness throughout evolution.

    Writes one row per generation to ``filename``, using the values returned
    by ``get_species_fitness``; ``null_value`` is passed through to it as the
    placeholder for species with no members in a generation. Fields are
    separated by ``delimiter``.
    """
    import sys

    # Gather the data first so a failure here does not leave an empty file.
    species_fitness = get_species_fitness(population, null_value)

    # The csv module emits its own line terminators, so the file must not
    # translate newlines: newline='' on Python 3, binary mode on Python 2.
    # Otherwise rows end in "\r\r\n" on Windows.
    if sys.version_info[0] >= 3:
        f = open(filename, 'w', newline='')
    else:
        f = open(filename, 'wb')

    with f:
        w = csv.writer(f, delimiter=delimiter)
        for s in species_fitness:
            w.writerow(s)
36 changes: 13 additions & 23 deletions neat/visualize.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# -*- coding: UTF-8 -*-
from __future__ import print_function

import warnings

from neat.statistics import get_average_fitness, get_species_sizes

try:
import graphviz
except ImportError:
Expand All @@ -20,23 +23,19 @@
np = None
warnings.warn('Could not import optional dependency NumPy.')

from neat.math_util import mean


def plot_stats(best_genomes, fitness_scores, ylog=False, view=False, filename='avg_fitness.svg'):
def plot_stats(population, ylog=False, view=False, filename='avg_fitness.svg'):
""" Plots the population's average and best fitness. """
if plt is None:
warnings.warn("This display is not available due to a missing optional dependency (matplotlib)")
return

generation = range(len(best_genomes))

fitness = [c.fitness for c in best_genomes]

avg_scores = [mean(f) for f in fitness_scores]
generation = range(len(population.most_fit_genomes))
best_fitness = [c.fitness for c in population.most_fit_genomes]
avg_fitness = get_average_fitness(population)

plt.plot(generation, avg_scores, 'b-', label="average")
plt.plot(generation, fitness, 'r-', label="best")
plt.plot(generation, avg_fitness, 'b-', label="average")
plt.plot(generation, best_fitness, 'r-', label="best")

plt.title("Population's average and best fitness")
plt.xlabel("Generations")
Expand Down Expand Up @@ -95,24 +94,15 @@ def plot_spikes(spikes, view=False, filename=None, title=None):
plt.close()


def plot_species(species_log, view=False, filename='speciation.svg'):
def plot_species(population, view=False, filename='speciation.svg'):
""" Visualizes speciation throughout evolution. """
if plt is None:
warnings.warn("This display is not available due to a missing optional dependency (matplotlib)")
return

num_generations = len(species_log)
all_species = set()
for gen_data in species_log:
for sid, scount in gen_data.items():
all_species.add(sid)

max_species = max(all_species)
curves = []
for gen_data in species_log:
species = [gen_data.get(sid, 0) for sid in range(max_species + 1)]
curves.append(np.array(species))
curves = np.array(curves).T
species_sizes = get_species_sizes(population)
num_generations = len(species_sizes)
curves = np.array(species_sizes).T

fig, ax = plt.subplots()
ax.stackplot(range(num_generations), *curves)
Expand Down

0 comments on commit fd5f2d7

Please sign in to comment.