forked from yaricom/goNEAT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
neat.go
326 lines (289 loc) · 11.3 KB
/
neat.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
// Package neat implements the NeuroEvolution of Augmenting Topologies (NEAT) method, which can be used to evolve
// specific Artificial Neural Networks from scratch using genetic algorithms.
package neat
import (
"context"
"fmt"
"github.com/AISystemsInc/goNEAT/v2/neat/math"
"github.com/pkg/errors"
"github.com/spf13/cast"
"gopkg.in/yaml.v3"
"io"
"io/ioutil"
"strconv"
"strings"
)
// GenomeCompatibilityMethod names the algorithm used to compute the compatibility of two genomes.
type GenomeCompatibilityMethod string

const (
	// GenomeCompatibilityMethodLinear selects the linear compatibility computation.
	GenomeCompatibilityMethodLinear GenomeCompatibilityMethod = "linear"
	// GenomeCompatibilityMethodFast selects the fast compatibility computation (intended for large genomes).
	GenomeCompatibilityMethodFast GenomeCompatibilityMethod = "fast"
)

// Validate returns an error when g is not one of the compatibility methods supported by the algorithm.
func (g GenomeCompatibilityMethod) Validate() error {
	switch g {
	case GenomeCompatibilityMethodLinear, GenomeCompatibilityMethodFast:
		return nil
	default:
		return errors.Errorf("unsupported genome compatibility method: [%s]", g)
	}
}
// EpochExecutorType names the strategy used to execute an evolutionary epoch.
type EpochExecutorType string

const (
	// EpochExecutorTypeSequential selects the sequential epoch executor.
	EpochExecutorTypeSequential EpochExecutorType = "sequential"
	// EpochExecutorTypeParallel selects the parallel epoch executor.
	EpochExecutorTypeParallel EpochExecutorType = "parallel"
)

// Validate returns an error when e is not one of the epoch executor types supported by the algorithm.
func (e EpochExecutorType) Validate() error {
	switch e {
	case EpochExecutorTypeSequential, EpochExecutorTypeParallel:
		return nil
	default:
		return errors.Errorf("unsupported epoch executor type: [%s]", e)
	}
}
// Options holds the configuration parameters of the NEAT algorithm.
//
// Options can be decoded from YAML by LoadYAMLOptions (using the `yaml` tags below) or
// parsed from the plain-text .neat format by LoadNeatOptions. Fields tagged `yaml:"-"`
// (NodeActivators, NodeActivatorsProb) are not read directly; initNodeActivators derives
// them from NodeActivatorsWithProbs.
type Options struct {
// Probability of mutating a single trait parameter
TraitParamMutProb float64 `yaml:"trait_param_mut_prob"`
// Power of a mutation applied to a single trait parameter
TraitMutationPower float64 `yaml:"trait_mutation_power"`
// The power of a link weight mutation
WeightMutPower float64 `yaml:"weight_mut_power"`
// These 3 global coefficients are used in the formula that
// computes the compatibility between 2 genomes. The formula is:
// disjoint_coeff * pdg + excess_coeff * peg + mutdiff_coeff * mdmg.
// See the compatibility method of the Genome type for more info.
// They can be thought of as the importance of disjoint genes,
// excess genes, and the parametric difference between genes of the
// same function, respectively.
DisjointCoeff float64 `yaml:"disjoint_coeff"`
ExcessCoeff float64 `yaml:"excess_coeff"`
MutdiffCoeff float64 `yaml:"mutdiff_coeff"`
// The compatibility threshold under which two genomes are
// considered to belong to the same species
CompatThreshold float64 `yaml:"compat_threshold"`
/* Globals involved in the epoch cycle - mating, reproduction, etc.. */
// How much age matters. Young species get a fitness boost up to some age (niching).
// If it is 1, then young species get no fitness boost.
AgeSignificance float64 `yaml:"age_significance"`
// Fraction of each species allowed to reproduce, based on survival_thresh * pop_size
SurvivalThresh float64 `yaml:"survival_thresh"`
// Probability of a non-mating (mutation only) reproduction
MutateOnlyProb float64 `yaml:"mutate_only_prob"`
// Probabilities of applying each specific kind of mutation, as named by the field
MutateRandomTraitProb float64 `yaml:"mutate_random_trait_prob"`
MutateLinkTraitProb float64 `yaml:"mutate_link_trait_prob"`
MutateNodeTraitProb float64 `yaml:"mutate_node_trait_prob"`
MutateLinkWeightsProb float64 `yaml:"mutate_link_weights_prob"`
MutateToggleEnableProb float64 `yaml:"mutate_toggle_enable_prob"`
MutateGeneReenableProb float64 `yaml:"mutate_gene_reenable_prob"`
MutateAddNodeProb float64 `yaml:"mutate_add_node_prob"`
MutateAddLinkProb float64 `yaml:"mutate_add_link_prob"`
// Probability of a mutation that connects disconnected inputs (sensors)
MutateConnectSensors float64 `yaml:"mutate_connect_sensors"`
// Probability of a mate being chosen from outside the species
InterspeciesMateRate float64 `yaml:"interspecies_mate_rate"`
// Probabilities of each mating method, as named by the field
// (multipoint, multipoint average, single point)
MateMultipointProb float64 `yaml:"mate_multipoint_prob"`
MateMultipointAvgProb float64 `yaml:"mate_multipoint_avg_prob"`
MateSinglepointProb float64 `yaml:"mate_singlepoint_prob"`
// Probability of mating without mutation
MateOnlyProb float64 `yaml:"mate_only_prob"`
// Probability of forcing selection of ONLY links that are naturally recurrent
RecurOnlyProb float64 `yaml:"recur_only_prob"`
// Size of the population
PopSize int `yaml:"pop_size"`
// Age at which a species starts to be penalized
DropOffAge int `yaml:"dropoff_age"`
// Number of attempts mutate_add_link makes to find an open link
NewLinkTries int `yaml:"newlink_tries"`
// Print the population to a file every n generations
PrintEvery int `yaml:"print_every"`
// The number of babies to be stolen off to the champions
BabiesStolen int `yaml:"babies_stolen"`
// The number of runs to average over in an experiment
NumRuns int `yaml:"num_runs"`
// The number of epochs (generations) to execute during training
NumGenerations int `yaml:"num_generations"`
// The epoch executor type to apply (sequential, parallel); checked by Validate
EpochExecutorType EpochExecutorType `yaml:"epoch_executor"`
// The genome compatibility method to use (linear, fast - fast makes sense for large genomes); checked by Validate
GenCompatMethod GenomeCompatibilityMethod `yaml:"genome_compat_method"`
// The neuron node activation functions to choose from (derived from NodeActivatorsWithProbs)
NodeActivators []math.NodeActivationType `yaml:"-"`
// The selection probability of each entry in NodeActivators, index-aligned with it
// (derived from NodeActivatorsWithProbs)
NodeActivatorsProb []float64 `yaml:"-"`
// NodeActivatorsWithProbs lists the supported node activations with the probability of each one.
// Every entry has the form "<activation name> <probability>". When empty, initNodeActivators
// defaults to math.SigmoidSteepenedActivation with probability 1.0.
NodeActivatorsWithProbs []string `yaml:"node_activators"`
// LogLevel is the log output detail level; the option loaders pass it to InitLogger
LogLevel string `yaml:"log_level"`
}
// RandomNodeActivationType returns a randomly chosen node activation type among the ones
// registered in these options. The choice is made by math.SingleRouletteThrow over
// NodeActivatorsProb. An error is returned if the throw yields an index outside NodeActivators.
func (c *Options) RandomNodeActivationType() (math.NodeActivationType, error) {
	// With exactly one registered activator there is nothing to draw.
	if len(c.NodeActivators) == 1 {
		return c.NodeActivators[0], nil
	}

	idx := math.SingleRouletteThrow(c.NodeActivatorsProb)
	if idx >= 0 && idx < len(c.NodeActivators) {
		return c.NodeActivators[idx], nil
	}
	return 0, fmt.Errorf("unexpected error when trying to find random node activator, activator index: %d", idx)
}
// initNodeActivators fills NodeActivators and NodeActivatorsProb from NodeActivatorsWithProbs.
//
// Each entry of NodeActivatorsWithProbs must consist of exactly two whitespace-separated
// fields: "<activation name> <probability>". The name is resolved with
// math.NodeActivators.ActivationTypeFromName and the probability is parsed as a float64.
// The results are stored index-aligned in NodeActivators and NodeActivatorsProb.
//
// When NodeActivatorsWithProbs is empty, the defaults are used: a single
// math.SigmoidSteepenedActivation with selection probability 1.0.
//
// It returns an error for a malformed entry, an unknown activation name, or an unparsable
// probability.
func (c *Options) initNodeActivators() (err error) {
	if len(c.NodeActivatorsWithProbs) == 0 {
		c.NodeActivators = []math.NodeActivationType{math.SigmoidSteepenedActivation}
		c.NodeActivatorsProb = []float64{1.0}
		return nil
	}

	actFns := c.NodeActivatorsWithProbs
	c.NodeActivators = make([]math.NodeActivationType, len(actFns))
	c.NodeActivatorsProb = make([]float64, len(actFns))
	for i, line := range actFns {
		fields := strings.Fields(line)
		// Reject malformed entries up front; indexing fields[0]/fields[1] would otherwise
		// panic on an empty entry or one missing its probability.
		if len(fields) != 2 {
			return errors.Errorf("invalid node activator entry %q: expected \"<activation name> <probability>\"", line)
		}
		if c.NodeActivators[i], err = math.NodeActivators.ActivationTypeFromName(fields[0]); err != nil {
			return err
		}
		if c.NodeActivatorsProb[i], err = strconv.ParseFloat(fields[1], 64); err != nil {
			return err
		}
	}
	return nil
}
// Validate checks that the enumerated settings of these options (the epoch executor type
// and the genome compatibility method) hold supported values. It returns the first
// validation error found, or nil when everything is valid.
func (c *Options) Validate() error {
	validators := []func() error{
		c.EpochExecutorType.Validate,
		c.GenCompatMethod.Validate,
	}
	for _, validate := range validators {
		if err := validate(); err != nil {
			return err
		}
	}
	return nil
}
// NeatContext returns a new context derived from context.Background() that carries these
// options (via NewContext), so they can be propagated to code that needs them.
func (c *Options) NeatContext() context.Context {
	parent := context.Background()
	return NewContext(parent, c)
}
// LoadYAMLOptions loads NEAT options encoded as YAML from r.
//
// It reads the whole input and then runs these steps in order:
//  1. decode it into Options,
//  2. initialize the logger with the configured log level,
//  3. build the node activator tables,
//  4. validate the result.
//
// A read error is returned as is. A failure in any later step is returned wrapped with a
// description of that step.
func LoadYAMLOptions(r io.Reader) (*Options, error) {
	raw, readErr := ioutil.ReadAll(r)
	if readErr != nil {
		return nil, readErr
	}

	opts := new(Options)
	// The steps must run in this order: later steps read fields filled in by earlier ones.
	steps := []struct {
		run  func() error
		desc string
	}{
		{run: func() error { return yaml.Unmarshal(raw, opts) }, desc: "failed to decode NEAT options from YAML"},
		{run: func() error { return InitLogger(opts.LogLevel) }, desc: "failed to initialize logger"},
		{run: opts.initNodeActivators, desc: "failed to read node activators"},
		{run: opts.Validate, desc: "invalid NEAT options"},
	}
	for _, step := range steps {
		if stepErr := step.run(); stepErr != nil {
			return nil, errors.Wrap(stepErr, step.desc)
		}
	}
	return opts, nil
}
// LoadNeatOptions Loads NEAT options configuration from provided reader encode in plain text format (.neat)
func LoadNeatOptions(r io.Reader) (*Options, error) {
c := &Options{}
// read configuration
var name string
var param string
for {
_, err := fmt.Fscanf(r, "%s %v", &name, ¶m)
if err == io.EOF {
break
}
switch name {
case "trait_param_mut_prob":
c.TraitParamMutProb = cast.ToFloat64(param)
case "trait_mutation_power":
c.TraitMutationPower = cast.ToFloat64(param)
case "weight_mut_power":
c.WeightMutPower = cast.ToFloat64(param)
case "disjoint_coeff":
c.DisjointCoeff = cast.ToFloat64(param)
case "excess_coeff":
c.ExcessCoeff = cast.ToFloat64(param)
case "mutdiff_coeff":
c.MutdiffCoeff = cast.ToFloat64(param)
case "compat_threshold":
c.CompatThreshold = cast.ToFloat64(param)
case "age_significance":
c.AgeSignificance = cast.ToFloat64(param)
case "survival_thresh":
c.SurvivalThresh = cast.ToFloat64(param)
case "mutate_only_prob":
c.MutateOnlyProb = cast.ToFloat64(param)
case "mutate_random_trait_prob":
c.MutateRandomTraitProb = cast.ToFloat64(param)
case "mutate_link_trait_prob":
c.MutateLinkTraitProb = cast.ToFloat64(param)
case "mutate_node_trait_prob":
c.MutateNodeTraitProb = cast.ToFloat64(param)
case "mutate_link_weights_prob":
c.MutateLinkWeightsProb = cast.ToFloat64(param)
case "mutate_toggle_enable_prob":
c.MutateToggleEnableProb = cast.ToFloat64(param)
case "mutate_gene_reenable_prob":
c.MutateGeneReenableProb = cast.ToFloat64(param)
case "mutate_add_node_prob":
c.MutateAddNodeProb = cast.ToFloat64(param)
case "mutate_add_link_prob":
c.MutateAddLinkProb = cast.ToFloat64(param)
case "mutate_connect_sensors":
c.MutateConnectSensors = cast.ToFloat64(param)
case "interspecies_mate_rate":
c.InterspeciesMateRate = cast.ToFloat64(param)
case "mate_multipoint_prob":
c.MateMultipointProb = cast.ToFloat64(param)
case "mate_multipoint_avg_prob":
c.MateMultipointAvgProb = cast.ToFloat64(param)
case "mate_singlepoint_prob":
c.MateSinglepointProb = cast.ToFloat64(param)
case "mate_only_prob":
c.MateOnlyProb = cast.ToFloat64(param)
case "recur_only_prob":
c.RecurOnlyProb = cast.ToFloat64(param)
case "pop_size":
c.PopSize = cast.ToInt(param)
case "dropoff_age":
c.DropOffAge = cast.ToInt(param)
case "newlink_tries":
c.NewLinkTries = cast.ToInt(param)
case "print_every":
c.PrintEvery = cast.ToInt(param)
case "babies_stolen":
c.BabiesStolen = cast.ToInt(param)
case "num_runs":
c.NumRuns = cast.ToInt(param)
case "num_generations":
c.NumGenerations = cast.ToInt(param)
case "epoch_executor":
c.EpochExecutorType = EpochExecutorType(param)
case "genome_compat_method":
c.GenCompatMethod = GenomeCompatibilityMethod(param)
case "log_level":
c.LogLevel = param
default:
return nil, errors.Errorf("unknown configuration parameter found: %s = %s", name, param)
}
}
// initialize logger
if err := InitLogger(c.LogLevel); err != nil {
return nil, errors.Wrap(err, "failed to initialize logger")
}
if err := c.initNodeActivators(); err != nil {
return nil, err
}
if err := c.Validate(); err != nil {
return nil, err
}
return c, nil
}