forked from CodeReclaimers/neat-python
-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathtest_xor_example.py
98 lines (79 loc) · 3.59 KB
/
test_xor_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from __future__ import print_function
import os
import neat
def test_xor_example_uniform_weights():
    """Run the XOR example test with uniform weight initialization enabled."""
    test_xor_example(uniform_weights=True)
def test_xor_example(uniform_weights=False):
    """Evolve a network for 2-input XOR, then re-run from the last checkpoint.

    A population built from the ``test_configuration`` file located next to
    this module is run for up to 100 generations; complete extinction is
    tolerated. If the checkpointer recorded a checkpoint generation in
    ``[0, 100)``, the population is restored from that checkpoint file and
    run for the remaining generations. The two runs must then agree on
    whether a winner was produced (the winners themselves are not compared).

    Args:
        uniform_weights: When true, set the genome config's
            ``weight_init_type`` to ``'uniform'`` and use a separate
            checkpoint filename prefix so the two variants do not share
            checkpoint files.

    Raises:
        AssertionError: If the statistics reporter has no fitness medians
            after the first run.
        Exception: If exactly one of the two runs produced a winner.
    """
    # Generation budget for the first run; the checkpoint replay gets
    # whatever remains after the checkpointed generation.
    max_generations = 100

    # 2-input XOR inputs and expected outputs.
    xor_inputs = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (1.0, 1.0)]
    xor_outputs = [(0.0,), (1.0,), (1.0,), (0.0,)]

    def eval_genomes(genomes, config):
        """Set each genome's fitness to 1.0 minus its summed squared XOR error."""
        for _genome_id, genome in genomes:
            genome.fitness = 1.0
            net = neat.nn.FeedForwardNetwork.create(genome, config)
            for xi, xo in zip(xor_inputs, xor_outputs):
                output = net.activate(xi)
                genome.fitness -= (output[0] - xo[0]) ** 2

    # Determine path to configuration file. This path manipulation is
    # here so that the script will run successfully regardless of the
    # current working directory.
    local_dir = os.path.dirname(__file__)
    config_path = os.path.join(local_dir, 'test_configuration')

    # Load configuration.
    config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction,
                         neat.DefaultSpeciesSet, neat.DefaultStagnation,
                         config_path)

    if uniform_weights:
        config.genome_config.weight_init_type = 'uniform'
        filename_prefix = 'neat-checkpoint-test_xor_uniform-'
    else:
        filename_prefix = 'neat-checkpoint-test_xor-'

    # Create the population, which is the top-level object for a NEAT run.
    p = neat.Population(config)

    # Add a stdout reporter to show progress in the terminal.
    p.add_reporter(neat.StdOutReporter(True))
    stats = neat.StatisticsReporter()
    p.add_reporter(stats)
    # NOTE(review): 25 and 10 are presumably the generation interval and the
    # time interval in seconds -- confirm against neat.Checkpointer.
    checkpointer = neat.Checkpointer(25, 10, filename_prefix)
    p.add_reporter(checkpointer)

    # Run for up to max_generations generations, allowing extinction.
    winner = None
    try:
        winner = p.run(eval_genomes, max_generations)
    except neat.CompleteExtinctionException:
        pass

    assert len(stats.get_fitness_median()), "Nothing returned from get_fitness_median()"

    if winner:
        if uniform_weights:
            print('\nUsing uniform weight initialization:')
        # Display the winning genome.
        print('\nBest genome:\n{!s}'.format(winner))

        # Show output of the most fit genome against training data.
        print('\nOutput:')
        winner_net = neat.nn.FeedForwardNetwork.create(winner, config)
        for xi, xo in zip(xor_inputs, xor_outputs):
            output = winner_net.activate(xi)
            print("input {!r}, expected output {!r}, got {!r}".format(xi, xo, output))

    # Replay only if a checkpoint generation was recorded and it leaves at
    # least one generation still to run.
    last_checkpoint = checkpointer.last_generation_checkpoint
    if 0 <= last_checkpoint < max_generations:
        filename = '{0}{1}'.format(filename_prefix, last_checkpoint)
        print("Restoring from {!s}".format(filename))
        p2 = neat.checkpoint.Checkpointer.restore_checkpoint(filename)
        p2.add_reporter(neat.StdOutReporter(True))
        stats2 = neat.StatisticsReporter()
        p2.add_reporter(stats2)

        winner2 = None
        try:
            winner2 = p2.run(eval_genomes, max_generations - last_checkpoint)
        except neat.CompleteExtinctionException:
            pass

        # Both runs must agree on whether any winner was produced.
        if winner2:
            if not winner:
                raise Exception("Had winner2 without first-try winner")
        elif winner:
            raise Exception("Had first-try winner without winner2")
if __name__ == '__main__':
    # Run every test defined in this module when executed as a script.
    # NOTE: calls to test_xor_example_multiparam_relu,
    # test_xor_example_multiparam_sigmoid_or_relu and
    # test_xor_example_multiparam_aggregation were removed because no such
    # functions are defined or imported here, so they raised NameError.
    # Restore the calls if/when those tests are added to this module.
    test_xor_example()
    test_xor_example_uniform_weights()