-
-
Notifications
You must be signed in to change notification settings - Fork 486
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merged @mstechly's statistics and index reset pull request.
Simplified Population statistics collection, and keep per-species fitness data. Simplified visualize and statistics functions to take the Population object directly, and factored out common statistical operations on Population statistics into statistics.py. Updated basic XOR example to show usage of statistics functions.
- Loading branch information
1 parent
c7f47b3
commit fd5f2d7
Showing
9 changed files
with
108 additions
and
83 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,54 +1,78 @@ | ||
# -*- coding: UTF-8 -*- | ||
from __future__ import print_function | ||
import warnings | ||
|
||
import csv | ||
|
||
try: | ||
import numpy as np | ||
except ImportError: | ||
np = None | ||
warnings.warn('Could not import optional dependency NumPy.') | ||
from neat.math_util import mean | ||
|
||
|
||
def get_average_fitness(population): | ||
avg_fitness = [] | ||
for stats in population.generation_statistics: | ||
scores = [] | ||
for fitness in stats.values(): | ||
scores.extend(fitness) | ||
avg_fitness.append(mean(scores)) | ||
|
||
return avg_fitness | ||
|
||
|
||
def get_species_sizes(population): | ||
all_species = set() | ||
for gen_data in population.generation_statistics: | ||
all_species = all_species.union(gen_data.keys()) | ||
|
||
max_species = max(all_species) | ||
species_counts = [] | ||
for gen_data in population.generation_statistics: | ||
species = [len(gen_data.get(sid, [])) for sid in range(1, max_species + 1)] | ||
species_counts.append(species) | ||
|
||
return species_counts | ||
|
||
def save_stats(best_genomes, avg_scores, ylog=False, view=False, filename='fitness_history.csv'): | ||
""" Saves the population's average and best fitness. """ | ||
csvfile = open(filename, 'wb') | ||
statWriter = csv.writer(csvfile, delimiter=' ') | ||
|
||
generation = range(len(best_genomes)) | ||
fitness = [c.fitness for c in best_genomes] | ||
def get_species_fitness(population, null_value=''): | ||
all_species = set() | ||
for gen_data in population.generation_statistics: | ||
all_species = all_species.union(gen_data.keys()) | ||
|
||
for i in generation: | ||
statWriter.writerow([fitness[i], avg_scores[i]]) | ||
max_species = max(all_species) | ||
species_fitness = [] | ||
for gen_data in population.generation_statistics: | ||
member_fitness = [gen_data.get(sid, []) for sid in range(1, max_species + 1)] | ||
fitness = [] | ||
for mf in member_fitness: | ||
if mf: | ||
fitness.append(mean(mf)) | ||
else: | ||
fitness.append(null_value) | ||
species_fitness.append(fitness) | ||
|
||
csvfile.close() | ||
return species_fitness | ||
|
||
|
||
def save_species_count(species_log, view=False, filename='speciation.csv'): | ||
""" Visualizes speciation throughout evolution. """ | ||
csvfile = open(filename, 'wb') | ||
statWriter = csv.writer(csvfile, delimiter=' ') | ||
def save_stats(population, delimiter=' ', filename='fitness_history.csv'): | ||
""" Saves the population's best and average fitness. """ | ||
with open(filename, 'w') as f: | ||
w = csv.writer(f, delimiter=delimiter) | ||
|
||
best_fitness = [c.fitness for c in population.most_fit_genomes] | ||
avg_fitness = get_average_fitness(population) | ||
for best, avg in zip(best_fitness, avg_fitness): | ||
w.writerow([best, avg]) | ||
|
||
num_generations = len(species_log) | ||
num_species = max(map(len, species_log)) | ||
curves = [] | ||
for gen in species_log: | ||
species = [0] * num_species | ||
species[:len(gen)] = gen | ||
statWriter.writerow(species) | ||
|
||
csvfile.close() | ||
def save_species_count(population, delimiter=' ', filename='speciation.csv'): | ||
""" Log speciation throughout evolution. """ | ||
with open(filename, 'w') as f: | ||
w = csv.writer(f, delimiter=delimiter) | ||
for s in get_species_sizes(population): | ||
w.writerow(s) | ||
|
||
def save_species_fitness(species_fitness_log, view=False, filename='species_fitness.csv'): | ||
""" Visualizes speciation throughout evolution. """ | ||
csvfile = open(filename, 'wb') | ||
statWriter = csv.writer(csvfile, delimiter=' ') | ||
num_generations = len(species_fitness_log) | ||
num_species = max(map(len, species_fitness_log)) | ||
curves = [] | ||
for gen in species_fitness_log: | ||
species = ["NA"] * num_species | ||
species[:len(gen)] = gen | ||
statWriter.writerow(species) | ||
|
||
csvfile.close() | ||
def save_species_fitness(population, delimiter=' ', null_value='NA', filename='species_fitness.csv'): | ||
""" Log species' average fitness throughout evolution. """ | ||
with open(filename, 'w') as f: | ||
w = csv.writer(f, delimiter=delimiter) | ||
for s in get_species_fitness(population, null_value): | ||
w.writerow(s) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters