Skip to content

Commit 178da57

Browse files
committed
Changed Backpropagation to use Java for performance reasons
1 parent 459fd5e commit 178da57

File tree

7 files changed

+119
-133
lines changed

7 files changed

+119
-133
lines changed

src/main/groovy/net/zomis/machlearn/images/ImageNetworkBuilder.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package net.zomis.machlearn.images;
22

3-
import net.zomis.machlearn.neural.BackPropagation;
3+
import net.zomis.machlearn.neural.Backpropagation;
44
import net.zomis.machlearn.neural.LearningData;
55
import net.zomis.machlearn.neural.NeuralNetwork;
66
import net.zomis.machlearn.neural.NeuronLayer;
@@ -36,7 +36,7 @@ public ImageNetworkBuilder classify(Object result, double[] input) {
3636
return this;
3737
}
3838

39-
public ImageNetwork learn(BackPropagation backprop, Random random) {
39+
public ImageNetwork learn(Backpropagation backprop, Random random) {
4040
int outputNodes = classifications.size() - 1;
4141

4242
NeuronLayer parentLayer = this.network.getLastLayer();

src/main/groovy/net/zomis/machlearn/neural/BackPropagation.groovy renamed to src/main/groovy/net/zomis/machlearn/neural/Backpropagation.java

Lines changed: 108 additions & 126 deletions
Large diffs are not rendered by default.

src/main/groovy/net/zomis/machlearn/neural/LearningData.groovy

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,8 @@ class LearningData {
1313
this.outputs = outputs
1414
}
1515

16+
double getInput(int i) {
17+
return inputs[i];
18+
}
19+
1620
}

src/main/groovy/net/zomis/machlearn/neural/NeuralMain.groovy

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class NeuralMain {
2222
examples << new LearningData([0, 1] as double[], [0, 1] as double[])
2323
examples << new LearningData([1, 0] as double[], [0, 1] as double[])
2424
examples << new LearningData([1, 1] as double[], [1, 1] as double[])
25-
new BackPropagation(0.2, 100000).backPropagationLearning(examples, network)
25+
new Backpropagation(0.2, 100000).backPropagationLearning(examples, network)
2626

2727
network.printAll()
2828

src/test/groovy/net/zomis/machlearn/images/Screenshoter.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package net.zomis.machlearn.images;
22

3-
import net.zomis.machlearn.neural.BackPropagation;
3+
import net.zomis.machlearn.neural.Backpropagation;
44
import net.zomis.machlearn.neural.Neuron;
55

66
import javax.imageio.ImageIO;
@@ -32,7 +32,7 @@ public static void main(String[] args) throws AWTException, IOException {
3232
.classify("clicked", analyze.imagePart(image, 790, 241))
3333
.classify("flag", analyze.imagePart(image, 790, 197))
3434
.classifyNone(analyze.imagePart(image, 0, 0))
35-
.learn(new BackPropagation(0.1, 100), new Random(42));
35+
.learn(new Backpropagation(0.1, 100), new Random(42));
3636
// SlidingWindowResult points = analyze.slidingWindow(network, image).scaleX(25, 60).step(4).overlapping(false).run();
3737
// network.getNetwork().printAll();
3838

src/test/groovy/net/zomis/machlearn/neural/BasicPropagationTest.groovy

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class BasicPropagationTest {
5757

5858
@Test
5959
void learn() {
60-
new BackPropagation(0.2, 100000).backPropagationLearning(examples, network)
60+
new Backpropagation(0.2, 100000).backPropagationLearning(examples, network)
6161

6262
network.printAll()
6363

src/test/groovy/net/zomis/machlearn/neural/LoadSaveTest.groovy

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class LoadSaveTest {
2525

2626
def outputLayer = network.createLayer('OUT')
2727
outputLayer.createNeuron().addInputs(middleLayer)
28-
new BackPropagation(0.2, 100000).backPropagationLearning(examples, network, new Random(42))
28+
new Backpropagation(0.2, 100000).backPropagationLearning(examples, network, new Random(42))
2929

3030
def savedNetwork = new ByteArrayOutputStream()
3131
network.save(savedNetwork)

0 commit comments

Comments
 (0)