-
Notifications
You must be signed in to change notification settings - Fork 0
/
NeuralNetwork.java
81 lines (71 loc) · 3.37 KB
/
NeuralNetwork.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
// A three-layer (input, hidden, output) feed-forward neural network that supports genetic crossover and mutation.
import java.util.Random;
class NeuralNetwork implements java.io.Serializable {
    // NOTE(review): no explicit serialVersionUID. Adding one now would change the default UID and
    // break deserialization of networks that were already saved, so it is deliberately left out.

    // Layer sizes; kept so offspring can be constructed with the same shape as their parents.
    int inputCount;
    int hiddenCount;
    int outputCount;
    // Input -> hidden weights, created as new Matrix(hiddenCount, inputCount).
    Matrix weightsIH;
    // Hidden -> output weights, created as new Matrix(outputCount, hiddenCount).
    Matrix weightsHO;
    // Hidden-layer bias, created as new Matrix(hiddenCount, 1).
    Matrix biasH;
    // Output-layer bias, created as new Matrix(outputCount, 1).
    Matrix biasO;

    /**
     * Creates a network with the given layer sizes and randomizes all weights and biases.
     *
     * @param inputCount  number of input nodes
     * @param hiddenCount number of hidden nodes
     * @param outputCount number of output nodes
     */
    public NeuralNetwork(int inputCount, int hiddenCount, int outputCount) {
        this.inputCount = inputCount;
        this.hiddenCount = hiddenCount;
        this.outputCount = outputCount;
        this.weightsIH = new Matrix(this.hiddenCount, this.inputCount);
        this.weightsHO = new Matrix(this.outputCount, this.hiddenCount);
        this.biasH = new Matrix(this.hiddenCount, 1);
        this.biasO = new Matrix(this.outputCount, 1);
        this.weightsIH.randomize();
        this.weightsHO.randomize();
        this.biasH.randomize();
        this.biasO.randomize();
    }

    /**
     * Runs a forward pass through the network.
     *
     * <p>Computes {@code hidden = sigmoid(weightsIH * input + biasH)} and then
     * {@code output = sigmoid(weightsHO * hidden + biasO)}.
     *
     * @param inputMatrix the input values; presumably an {@code inputCount x 1} column
     *                    (TODO confirm against callers)
     * @return the output-layer activations
     * @throws Exception if a {@code Matrix} operation fails (e.g. mismatched dimensions)
     */
    public Matrix predict(Matrix inputMatrix) throws Exception {
        // Multiply the weights by the input to get the hidden layer's raw values.
        Matrix hiddenMatrix = Matrix.matrixMultiplication(this.weightsIH, inputMatrix);
        // add() is called for its side effect on hiddenMatrix (its return value is unused).
        hiddenMatrix.add(this.biasH);
        // Activation function squashes every value into the range (0, 1).
        hiddenMatrix.mapSigmoid();
        // Repeat the process for the hidden -> output layer.
        Matrix outputMatrix = Matrix.matrixMultiplication(this.weightsHO, hiddenMatrix);
        outputMatrix.add(this.biasO);
        outputMatrix.mapSigmoid();
        return outputMatrix;
    }

    /**
     * Performs single-point crossover between two parents using a fresh random generator.
     *
     * @param dad first parent
     * @param mom second parent; must have the same layer sizes as {@code dad}
     * @return a two-element array {@code {baby1, baby2}}
     * @throws IllegalArgumentException if the parents' layer sizes differ
     * @throws Exception if a {@code Matrix} operation fails
     */
    public static NeuralNetwork[] crossover(NeuralNetwork dad, NeuralNetwork mom) throws Exception {
        return crossover(dad, mom, new Random());
    }

    /**
     * Performs single-point crossover between two parents, producing two offspring.
     *
     * <p>For each layer a cut index is drawn from {@code [0, rows)}. Each parent's matrix is split
     * at that index with {@code Matrix.transverseCut}. The first child takes dad's first part and
     * mom's second part; the second child takes the reverse.
     *
     * <p>Each layer's bias is split with the same cut as that layer's weights. As a result, the
     * offspring inherit their parents' biases rather than keeping the random values assigned by
     * the constructor.
     *
     * @param dad  first parent
     * @param mom  second parent; must have the same layer sizes as {@code dad}
     * @param rand source of the cut positions; pass a seeded instance for reproducible crossover
     * @return a two-element array {@code {baby1, baby2}}
     * @throws IllegalArgumentException if the parents' layer sizes differ
     * @throws Exception if a {@code Matrix} operation fails
     */
    public static NeuralNetwork[] crossover(NeuralNetwork dad, NeuralNetwork mom, Random rand)
            throws Exception {
        if (dad.inputCount != mom.inputCount
                || dad.hiddenCount != mom.hiddenCount
                || dad.outputCount != mom.outputCount) {
            throw new IllegalArgumentException("Parent networks have different shapes: dad="
                    + dad.inputCount + "-" + dad.hiddenCount + "-" + dad.outputCount
                    + ", mom=" + mom.inputCount + "-" + mom.hiddenCount + "-" + mom.outputCount);
        }
        NeuralNetwork baby1 = new NeuralNetwork(dad.inputCount, dad.hiddenCount, dad.outputCount);
        NeuralNetwork baby2 = new NeuralNetwork(dad.inputCount, dad.hiddenCount, dad.outputCount);
        int cutIH = rand.nextInt(dad.weightsIH.rows);
        int cutHO = rand.nextInt(dad.weightsHO.rows);

        // Input -> hidden weights.
        Matrix[] children = crossLayer(dad.weightsIH, mom.weightsIH, cutIH);
        baby1.weightsIH = children[0];
        baby2.weightsIH = children[1];
        // Hidden -> output weights.
        children = crossLayer(dad.weightsHO, mom.weightsHO, cutHO);
        baby1.weightsHO = children[0];
        baby2.weightsHO = children[1];

        // Assumes transverseCut splits by row index (the cut is bounded by the row count above).
        // Under that assumption, a (rows x 1) bias can be split at its layer's cut.
        // TODO(review): confirm against Matrix.transverseCut.
        children = crossLayer(dad.biasH, mom.biasH, cutIH);
        baby1.biasH = children[0];
        baby2.biasH = children[1];
        children = crossLayer(dad.biasO, mom.biasO, cutHO);
        baby1.biasO = children[0];
        baby2.biasO = children[1];

        return new NeuralNetwork[] {baby1, baby2};
    }

    /** Splits both parents' matrices at {@code cut} and recombines them into two child matrices. */
    private static Matrix[] crossLayer(Matrix dadLayer, Matrix momLayer, int cut) throws Exception {
        Matrix[] dadSplit = Matrix.transverseCut(dadLayer, cut);
        Matrix[] momSplit = Matrix.transverseCut(momLayer, cut);
        return new Matrix[] {
            Matrix.combineColumnsMatrices(dadSplit[0], momSplit[1]),
            Matrix.combineColumnsMatrices(momSplit[0], dadSplit[1])
        };
    }

    /**
     * Mutates every weight and bias matrix in place via {@code Matrix.mutate(chance)}.
     *
     * @param chance mutation chance passed to each matrix (meaning defined by {@code Matrix.mutate})
     * @throws Exception if a {@code Matrix} operation fails
     */
    public void mutate(double chance) throws Exception {
        this.weightsIH.mutate(chance);
        this.weightsHO.mutate(chance);
        this.biasH.mutate(chance);
        this.biasO.mutate(chance);
    }

    /**
     * Mutates every weight and bias matrix in place via {@code Matrix.mutate(chance, rangeOfChange)}.
     *
     * @param chance        mutation chance passed to each matrix
     * @param rangeOfChange mutation magnitude passed to each matrix (meaning defined by
     *                      {@code Matrix.mutate})
     * @throws Exception if a {@code Matrix} operation fails
     */
    public void mutate(double chance, double rangeOfChange) throws Exception {
        this.weightsIH.mutate(chance, rangeOfChange);
        this.weightsHO.mutate(chance, rangeOfChange);
        this.biasH.mutate(chance, rangeOfChange);
        this.biasO.mutate(chance, rangeOfChange);
    }
}