LeNet.java
package com.bsiag.anagnostes.server.neuralnetwork;

import java.io.IOException;

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class LeNet {

  public static final int NUM_OUTPUTS = 10;
  public static final int BATCH_SIZE = 64;

  private static final int NUM_CHANNELS = 1;
  private static final int NUM_ITERATIONS = 1;
  private static final int SEED = 123;
  /**
   * Regarding the .setInputType(InputType.convolutionalFlat(28, 28, 1)) line: this does a few things.
   * (a) It adds preprocessors, which handle things like the transition between the convolutional/subsampling
   * layers and the dense layer.
   * (b) It performs some additional configuration validation.
   * (c) Where necessary, it sets the nIn (number of input neurons, or input depth in the case of CNNs) values
   * for each layer based on the size of the previous layer, but it won't override values manually set by the user.
   * InputTypes can be used with other layer types too (RNNs, MLPs, etc.), not just CNNs. For normal images (when
   * using ImageRecordReader) use InputType.convolutional(height, width, depth). The MNIST record reader is a
   * special case that outputs 28x28 pixel grayscale (nChannels=1) images in a "flattened" row vector format
   * (i.e., 1x784 vectors), hence the "convolutionalFlat" input type used here. An illustrative usage sketch
   * follows this method below.
   */
  public static MultiLayerConfiguration networkConfiguration() {
    return new NeuralNetConfiguration.Builder()
        .seed(SEED)
        .weightInit(WeightInit.XAVIER)
        .iterations(NUM_ITERATIONS)
        .regularization(true).l2(0.0005)
        .learningRate(.01)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .updater(Updater.NESTEROVS).momentum(0.9)
        .list()
        .layer(0, new ConvolutionLayer.Builder(5, 5)
            .stride(1, 1)
            .nIn(NUM_CHANNELS)
            .nOut(20)
            .activation(Activation.IDENTITY)
            .build())
        .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
            .kernelSize(2, 2)
            .stride(2, 2)
            .build())
        .layer(2, new ConvolutionLayer.Builder(5, 5)
            .stride(1, 1)
            .nOut(50)
            .activation(Activation.IDENTITY)
            .build())
        .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
            .kernelSize(2, 2)
            .stride(2, 2)
            .build())
        .layer(4, new DenseLayer.Builder()
            .activation(Activation.RELU)
            .nOut(500)
            .build())
        .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
            .activation(Activation.SOFTMAX)
            .nOut(NUM_OUTPUTS)
            .build())
        .setInputType(InputType.convolutionalFlat(28, 28, 1))
        .backprop(true)
        .pretrain(false)
        .build();
  }
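
  /**
   * Illustrative usage sketch (not part of the original class): one way the configuration above might be used to
   * build, train, and evaluate a network on MNIST with the DL4J 0.9.x API. The epoch count and listener frequency
   * are assumptions made for this example; fully qualified class names are used so that no extra imports are needed.
   */
  public static void trainAndEvaluateOnMnistExample() {
    org.deeplearning4j.nn.multilayer.MultiLayerNetwork network =
        new org.deeplearning4j.nn.multilayer.MultiLayerNetwork(networkConfiguration());
    network.init();
    // Log the training score every 10 parameter updates (frequency is an arbitrary choice for this sketch)
    network.setListeners(new org.deeplearning4j.optimize.listeners.ScoreIterationListener(10));

    DataSetIterator trainIterator = mnistTrainSetIterator();
    int numEpochs = 1; // assumption: a single epoch is enough for a demonstration run
    for (int epoch = 0; epoch < numEpochs; epoch++) {
      network.fit(trainIterator);
      trainIterator.reset();
    }

    // Evaluate accuracy, precision, recall and F1 on the held-out MNIST test set
    org.deeplearning4j.eval.Evaluation evaluation = network.evaluate(mnistTestSetIterator());
    System.out.println(evaluation.stats());
  }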
  public static DataSetIterator mnistTrainSetIterator() {
    try {
      return new MnistDataSetIterator(BATCH_SIZE, true, 12345);
    } catch (IOException e) {
      throw new RuntimeException("Couldn't build the MnistDataSetIterator", e);
    }
  }

  public static DataSetIterator mnistTestSetIterator() {
    try {
      return new MnistDataSetIterator(BATCH_SIZE, false, 12345);
    } catch (IOException e) {
      throw new RuntimeException("Couldn't build the MnistDataSetIterator", e);
    }
  }
  public static DataSetIterator numbersTrainSetIterator(String numbersBaseFolder) {
    return new NumbersDatasetIterator(BATCH_SIZE, numbersBaseFolder, true);
  }

  public static DataSetIterator numbersTrainSetIterator() {
    return new NumbersDatasetIterator(BATCH_SIZE, true);
  }

  public static DataSetIterator numbersTestSetIterator(String numbersBaseFolder) {
    return new NumbersDatasetIterator(BATCH_SIZE, numbersBaseFolder, false);
  }

  public static DataSetIterator numbersTestSetIterator() {
    return new NumbersDatasetIterator(BATCH_SIZE, false);
  }
}