-
Notifications
You must be signed in to change notification settings - Fork 0
/
activation_compare.py
67 lines (57 loc) · 2.02 KB
/
activation_compare.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import copy
import csv
from itertools import islice
from locale import setlocale, LC_ALL, str as lstr
from typing import List

from neural_network import Network, Entry, sigmoid, tanh, relu
# Switch to the user's default (system) locale -- Polish on the author's
# machine -- so numbers are later formatted with a decimal comma for Excel.
setlocale(LC_ALL, "")

# Activation functions to compare; their order is the column order of the
# generated CSV.
funcs = sigmoid, tanh, relu
# Read the training data. Each non-blank CSV row holds four numeric
# features followed by the species label; the label is one-hot encoded
# into three network outputs (setosa, versicolor, virginica).
_IRIS_CLASSES = ("Iris-setosa", "Iris-versicolor", "Iris-virginica")
with open("bezdekIris.data", "r", newline="") as file:  # csv module wants newline=""
    samples = []
    for row in csv.reader(file):
        # The data file may end with blank line(s); csv yields [] (or only
        # empty fields) for them, and row[4] would raise IndexError.
        if not any(field.strip() for field in row):
            continue
        iris_type = row[4].strip()
        try:
            class_index = _IRIS_CLASSES.index(iris_type)
        except ValueError:
            # Fail loudly instead of silently training on an all-zero target.
            raise ValueError(
                f"unknown iris type {iris_type!r} in row {row!r}"
            ) from None
        iris_outputs = [0, 0, 0]
        iris_outputs[class_index] = 1
        samples.append(Entry([float(n) for n in row[0:4]], iris_outputs))
def train_network():
    """Train one freshly created network once with every activation function.

    A single ``Network([4, 4, 3], 0.02)`` is created and its initial state is
    exported. Before each function in ``funcs`` is trained, the network is
    reset to that same initial state, so all functions start from an
    identical network.

    Returns:
        list: one list per entry of ``funcs`` (same order), each holding the
        error values yielded by ``net.teach_loop(samples)`` during the first
        ``cycles`` teaching cycles.
    """
    # prepare the network
    net = Network([4, 4, 3], 0.02)
    # Snapshot the generated net data for comparison. NOTE(review): whether
    # export_data/import_data copy or share their data is not visible here,
    # so deep copies guarantee training can never mutate the snapshot (which
    # would make later functions start from already-trained state).
    net_data = copy.deepcopy(net.export_data())

    def train(activation_f):
        net.import_data(copy.deepcopy(net_data))  # restore original state
        net.set_activation_f(activation_f)
        # islice takes exactly `cycles` errors (none when cycles <= 0) and
        # then stops drawing from the teaching generator.
        return list(islice(net.teach_loop(samples), max(cycles, 0)))

    funs_data = []
    for f in funcs:
        print(f.__name__)  # progress indicator
        funs_data.append(train(f))
    return funs_data
# number of independent experiments to average, and teaching cycles per run
times = 10
cycles = 100

# Repeat the whole experiment `times` times, accumulating per-cycle error
# sums for every activation function in avg_data.
avg_data = train_network()
for _ in range(times - 1):
    data = train_network()
    # consolidate: add this run's errors element-wise onto the running sums
    for i, (new_func_data, avg_func_data) in enumerate(zip(data, avg_data)):
        avg_data[i] = [sum(d) for d in zip(new_func_data, avg_func_data)]

# Divide by the number of runs and format with the active locale
# (decimal comma under a Polish locale).
for i, func_data in enumerate(avg_data):
    avg_data[i] = [lstr(error / times) for error in func_data]

# transpose for Excel: one row per teaching cycle, one column per function
avg_data = list(zip(*avg_data))

# Save. newline="" is required by the csv module: without it, on Windows each
# row ends in "\r\r\n" and Excel shows an empty line between rows.
with open("neural_data.csv", "w", newline="") as file:
    # semicolons, because the values themselves may contain decimal commas
    writer = csv.writer(file, delimiter=";")
    writer.writerows(avg_data)