-
Notifications
You must be signed in to change notification settings - Fork 0
/
Train.go
82 lines (67 loc) · 1.9 KB
/
Train.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
package NeuralNetworkGo
import (
"log"
Matrix "github.com/9init/NeuralNetworkGo/Matrix"
)
// Train runs one iteration of backpropagation on a single
// (input, target) training example, updating the network's weights
// and biases in place.
//
// Steps: feed-forward through the hidden and output layers with a
// sigmoid activation, compute the output and hidden errors, scale
// the resulting gradients by the learning rate, and apply the
// deltas to WeightIH/WeightHO and BiasH/BiasO.
//
// NOTE(review): several locals (output_errors, outputs_G, hidden_T,
// whoT, hidden_G, input_T) alias the matrices they are assigned
// from; the in-place Matrix operations (Transpose, Map,
// SuptractMatrix, ...) therefore appear to mutate the originals as
// well — e.g. whoT := neural.WeightHO followed by whoT.Transpose()
// looks like it transposes the live weight matrix. Confirm against
// the Matrix package's value/pointer semantics.
func (neural *NeuralN) Train(inputArray, targetArray []float64) {
	neural.check(inputArray, targetArray)
	targets := Matrix.NewFromArray(targetArray)
	inputs := Matrix.NewFromArray(inputArray)

	// Feed-forward: input -> hidden.
	hidden, err := neural.WeightIH.StaticDotProduct(inputs)
	if err != nil {
		log.Fatal(err)
	}
	hidden.AddFromMatrix(neural.BiasH)
	hidden.Map(sigmoid)

	// Feed-forward: hidden -> output.
	// (Previously this error was assigned but never checked before
	// outputs was used, then silently overwritten.)
	outputs, err := neural.WeightHO.StaticDotProduct(hidden)
	if err != nil {
		log.Fatal(err)
	}
	outputs.AddFromMatrix(neural.BiasO)
	outputs.Map(sigmoid)

	// Output error = targets - outputs.
	output_errors := targets
	output_errors.SuptractMatrix(outputs)

	// Output gradient: dsigmoid(outputs) (x) error, scaled by the
	// learning rate. dsigmoid here is X * (1 - X), i.e. it expects
	// the already-activated value.
	outputs_G := outputs
	outputs_G.Map(dsigmoid)
	if _, err = outputs_G.HadProduct(output_errors); err != nil {
		log.Fatal(err)
	}
	outputs_G.Multiply(neural.LearningRate)

	// Hidden->output weight delta = output gradient . hidden^T.
	hidden_T := hidden
	hidden_T.Transpose()
	weights_HO_G, err := outputs_G.StaticDotProduct(hidden_T)
	if err != nil {
		log.Fatal(err)
	}
	// Adjust the hidden->output weights by the delta and the output
	// bias by the gradient.
	neural.WeightHO.AddFromMatrix(weights_HO_G)
	neural.BiasO.AddFromMatrix(outputs_G)

	// Hidden layer error = WeightHO^T . output error.
	whoT := neural.WeightHO
	whoT.Transpose()
	hidden_errors, err := whoT.StaticDotProduct(output_errors)
	if err != nil {
		log.Fatal(err)
	}

	// Hidden gradient: dsigmoid(hidden) (x) hidden error, scaled by
	// the learning rate.
	hidden_G := hidden
	hidden_G.Map(dsigmoid)
	if _, err = hidden_G.HadProduct(hidden_errors); err != nil {
		log.Fatal(err)
	}
	hidden_G.Multiply(neural.LearningRate)

	// Input->hidden weight delta = hidden gradient . input^T.
	input_T := inputs
	input_T.Transpose()
	// BUG FIX: this error was previously discarded with `_` while the
	// stale err from the earlier dot product was re-checked below.
	weight_HI_Delta, err := hidden_G.StaticDotProduct(input_T)
	if err != nil {
		log.Fatal(err)
	}
	// Adjust the input->hidden weights by the delta and the hidden
	// bias by the gradient.
	neural.WeightIH.AddFromMatrix(weight_HI_Delta)
	neural.BiasH.AddFromMatrix(hidden_G)
}