-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for loading and saving Neural Network
- Loading branch information
Showing
3 changed files
with
104 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
56 changes: 56 additions & 0 deletions
56
src/test/groovy/net/zomis/machlearn/neural/LoadSaveTest.groovy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
package net.zomis.machlearn.neural | ||
|
||
import org.junit.Test | ||
|
||
class LoadSaveTest { | ||
|
||
@Test | ||
void simpleLoadSave() { | ||
List<LearningData> examples = Arrays.asList( | ||
new LearningData([0, 0] as double[], [0] as double[]), | ||
new LearningData([0, 1] as double[], [1] as double[]), | ||
new LearningData([1, 0] as double[], [1] as double[]), | ||
new LearningData([1, 1] as double[], [0] as double[]), | ||
) | ||
def network = new NeuralNetwork() | ||
def inputLayer = network.createLayer('IN') | ||
inputLayer.createNeuron() | ||
inputLayer.createNeuron() | ||
|
||
def middleLayer = network.createLayer('MIDDLE') | ||
middleLayer.createNeuron().addInputs(inputLayer) | ||
middleLayer.createNeuron().addInputs(inputLayer) | ||
|
||
def outputLayer = network.createLayer('OUT') | ||
outputLayer.createNeuron().addInputs(middleLayer) | ||
new BackPropagation(0.2, 100000).backPropagationLearning(examples, network, new Random(42)) | ||
|
||
def savedNetwork = new ByteArrayOutputStream() | ||
network.save(savedNetwork) | ||
|
||
assert savedNetwork.toByteArray() == LoadSaveTest.class.getClassLoader() | ||
.getResource('simplenetwork.network').bytes | ||
} | ||
|
||
@Test | ||
void loadTest() { | ||
def network = NeuralNetwork.load(LoadSaveTest.class.getClassLoader() | ||
.getResourceAsStream('simplenetwork.network')) | ||
assert network.layerCount == 3 | ||
assert network.getLayer(0).size() == 2 | ||
assert network.getLayer(1).size() == 2 | ||
assert network.getLayer(2).size() == 1 | ||
|
||
assert network.getLayer(0).name == 'IN' | ||
assert network.getLayer(1).name == 'MIDDLE' | ||
assert network.getLayer(2).name == 'OUT' | ||
def loadedNetwork = new ByteArrayOutputStream() | ||
network.save(loadedNetwork) | ||
|
||
def runResult = network.run([1, 0] as double[]) | ||
println runResult | ||
assert loadedNetwork.toByteArray() == LoadSaveTest.class.getClassLoader() | ||
.getResource('simplenetwork.network').bytes | ||
} | ||
|
||
} |
Binary file not shown.