Skip to content

Commit 6613411

Browse files
committed
Significantly increased speed by converting from Groovy to Java
1 parent dede075 commit 6613411

17 files changed

+408
-341
lines changed

src/main/groovy/net/zomis/machlearn/neural/Backpropagation.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ public NeuralNetwork backPropagationLearning(Collection<LearningData> examples,
8989
for (int nodei = 0; nodei < layer.size(); nodei++) {
9090
Neuron neuron = layer.getNeurons().get(nodei);
9191
double sum = neuron.getOutputs().stream().mapToDouble(link ->
92-
link.getWeight() * deltas[layerIdx][link.getTo().getIndexInLayer()]
92+
link.getWeight() * deltas[layerIdx][link.getTo().indexInLayer]
9393
).sum();
9494
double gPrim = neuron.getOutput() * (1 - neuron.getOutput());
9595
double delta = sum * gPrim;

src/main/groovy/net/zomis/machlearn/neural/DummyConnection.groovy

Lines changed: 0 additions & 32 deletions
This file was deleted.
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
package net.zomis.machlearn.neural;
2+
3+
class DummyConnection implements NeuronLink {
4+
5+
private double weight = 1;
6+
7+
@Override
8+
public double calculateInput() {
9+
return getInputValue() * this.weight;
10+
}
11+
12+
@Override
13+
public double getInputValue() {
14+
return 1;
15+
}
16+
17+
@Override
18+
public double getWeight() {
19+
return this.weight;
20+
}
21+
22+
@Override
23+
public void setWeight(double value) {
24+
this.weight = value;
25+
}
26+
27+
@Override
28+
public String toString() {
29+
return "w0 $weight";
30+
}
31+
32+
}

src/main/groovy/net/zomis/machlearn/neural/LearningData.groovy

Lines changed: 0 additions & 20 deletions
This file was deleted.
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
package net.zomis.machlearn.neural;
2+
3+
import java.util.Arrays;
4+
5+
public class LearningData {
6+
7+
public final double[] inputs;
8+
public final double[] outputs;
9+
10+
public LearningData(double[] inputs, double[] outputs) {
11+
this.inputs = inputs;
12+
this.outputs = outputs;
13+
}
14+
15+
public double getInput(int i) {
16+
return inputs[i];
17+
}
18+
19+
@Override
20+
public String toString() {
21+
return "LearningData{" +
22+
"inputs=" + Arrays.toString(inputs) +
23+
", outputs=" + Arrays.toString(outputs) +
24+
'}';
25+
}
26+
27+
public double[] getOutputs() {
28+
return outputs;
29+
}
30+
31+
public double[] getInputs() {
32+
return inputs;
33+
}
34+
}

src/main/groovy/net/zomis/machlearn/neural/NeuralNetwork.groovy

Lines changed: 0 additions & 126 deletions
This file was deleted.
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
package net.zomis.machlearn.neural;
2+
3+
import java.io.*;
4+
import java.util.ArrayList;
5+
import java.util.List;
6+
import java.util.stream.Stream;
7+
8+
public class NeuralNetwork {
9+
10+
public List<NeuronLayer> layers = new ArrayList<>();
11+
12+
public List<NeuronLayer> getLayers() {
13+
return layers;
14+
}
15+
16+
public NeuronLayer getInputLayer() {
17+
return getLayer(0);
18+
}
19+
20+
public NeuronLayer getOutputLayer() {
21+
return getLayer(layers.size() - 1);
22+
}
23+
24+
public NeuronLayer getLayer(int layerIndex) {
25+
return layers.get(layerIndex);
26+
}
27+
28+
public NeuronLayer createLayer(String name) {
29+
NeuronLayer layer = new NeuronLayer(name);
30+
this.layers.add(layer);
31+
return layer;
32+
}
33+
34+
public int getLayerCount() {
35+
return layers.size();
36+
}
37+
38+
public Stream<NeuronLink> links() {
39+
return this.layers.stream()
40+
.skip(1)
41+
.flatMap(it -> it.neurons.stream())
42+
.flatMap(it -> it.inputs.stream());
43+
}
44+
45+
public void printAll() {
46+
System.out.println("$layerCount layers:");
47+
layers.stream().forEach(it -> {
48+
it.printNodes();
49+
System.out.println();
50+
});
51+
System.out.println();
52+
}
53+
54+
void save(OutputStream output) {
55+
try (DataOutputStream it = new DataOutputStream(output)) {
56+
it.writeInt(this.getLayerCount());
57+
for (NeuronLayer layer : layers) {
58+
it.writeInt(layer.size());
59+
it.writeInt(layer.name.length());
60+
it.writeBytes(layer.name);
61+
}
62+
for (NeuronLayer layer : layers) {
63+
for (Neuron neuron : layer) {
64+
for (NeuronLink link : neuron.inputs) {
65+
it.writeDouble(link.getWeight());
66+
}
67+
}
68+
}
69+
} catch (IOException e) {
70+
throw new RuntimeException(e);
71+
}
72+
}
73+
74+
static NeuralNetwork load(InputStream input) {
75+
NeuralNetwork network = new NeuralNetwork();
76+
try (DataInputStream it = new DataInputStream(input)) {
77+
int layers = it.readInt();
78+
for (int i = 0; i < layers; i++) {
79+
int size = it.readInt();
80+
int nameLength = it.readInt();
81+
StringBuilder name = new StringBuilder();
82+
for (int nameIndex = 0; nameIndex < nameLength; nameIndex++) {
83+
name.append((char) it.readByte());
84+
}
85+
NeuronLayer layer = network.createLayer(name.toString());
86+
for (int j = 0; j < size; j++) {
87+
layer.createNeuron();
88+
}
89+
}
90+
for (int i = 0; i < layers; i++) {
91+
NeuronLayer layer = network.getLayer(i);
92+
if (i > 0) {
93+
final int ii = i;
94+
layer.neurons.forEach(it2 -> it2.addInputs(network.getLayer(ii - 1)));
95+
}
96+
for (Neuron neuron : layer) {
97+
for (NeuronLink link : neuron.inputs) {
98+
link.setWeight(it.readDouble());
99+
}
100+
}
101+
}
102+
} catch (IOException e) {
103+
throw new RuntimeException(e);
104+
}
105+
return network;
106+
}
107+
108+
public double[] run(double[] input) {
109+
double[] output = new double[getOutputLayer().size()];
110+
assert input.length == getInputLayer().size();
111+
for (int i = 0; i < getInputLayer().size(); i++) {
112+
getInputLayer().neurons.get(i).output = input[i];
113+
}
114+
115+
int layerIndex = 0;
116+
for (NeuronLayer layer : layers) {
117+
if (layerIndex++ == 0) {
118+
// Do not process input layer
119+
continue;
120+
}
121+
for (Neuron node : layer) {
122+
node.process();
123+
}
124+
}
125+
for (int i = 0; i < getOutputLayer().size(); i++) {
126+
output[i] = getOutputLayer().neurons.get(i).output;
127+
}
128+
/* for (int i = 0; i < inputLayer.size(); i++) {
129+
inputLayer.getNeurons().get(i).output = 1
130+
}*/
131+
132+
return output;
133+
}
134+
135+
public NeuronLayer getLastLayer() {
136+
return getLayer(layers.size() - 1);
137+
}
138+
139+
}

0 commit comments

Comments
 (0)