-
Notifications
You must be signed in to change notification settings - Fork 4.3k
/
MNISTClassifier.cs
178 lines (154 loc) · 9.66 KB
/
MNISTClassifier.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
using System.Collections.Generic;
using System.IO;
using System.Linq;
namespace CNTK.CSTrainingExamples
{
/// <summary>
/// This class shows how to build and train a classifier for handwritting data (MNIST).
/// For more details, please follow a serial of tutorials below:
/// https://github.com/Microsoft/CNTK/blob/master/Tutorials/CNTK_103A_MNIST_DataLoader.ipynb
/// https://github.com/Microsoft/CNTK/blob/master/Tutorials/CNTK_103B_MNIST_LogisticRegression.ipynb
/// https://github.com/Microsoft/CNTK/blob/master/Tutorials/CNTK_103C_MNIST_MultiLayerPerceptron.ipynb
/// https://github.com/Microsoft/CNTK/blob/master/Tutorials/CNTK_103D_MNIST_ConvolutionalNeuralNetwork.ipynb
/// </summary>
public class MNISTClassifier
{
    /// <summary>
    /// Root of the MNIST CNTK-text-format data files, relative to the execution folder.
    /// Execution folder is: CNTK/x64/BuildFolder
    /// Data folder is: CNTK/Tests/EndToEndTests/Image/Data
    /// This class shows how to build and train a classifier for handwriting data (MNIST).
    /// For more details, please follow the series of tutorials below:
    /// https://github.com/Microsoft/CNTK/blob/master/Tutorials/CNTK_103A_MNIST_DataLoader.ipynb
    /// https://github.com/Microsoft/CNTK/blob/master/Tutorials/CNTK_103B_MNIST_LogisticRegression.ipynb
    /// https://github.com/Microsoft/CNTK/blob/master/Tutorials/CNTK_103C_MNIST_MultiLayerPerceptron.ipynb
    /// https://github.com/Microsoft/CNTK/blob/master/Tutorials/CNTK_103D_MNIST_ConvolutionalNeuralNetwork.ipynb
    /// </summary>
    public static string ImageDataFolder = "../../Tests/EndToEndTests/Image/Data";

    /// <summary>
    /// Trains and evaluates an image classifier for the MNIST data set.
    /// </summary>
    /// <param name="device">CPU or GPU device to run training and evaluation on</param>
    /// <param name="useConvolution">true to build a convolutional network; false to build a multilayer perceptron</param>
    /// <param name="forceRetrain">whether to override an existing model:
    /// if true, any existing model is retrained from scratch and the new one evaluated;
    /// if false and a saved model exists, that model is evaluated instead.</param>
    public static void TrainAndEvaluate(DeviceDescriptor device, bool useConvolution, bool forceRetrain)
    {
        const string featureStreamName = "features";
        const string labelsStreamName = "labels";
        const string classifierName = "classifierOutput";
        const int imageSize = 28 * 28;
        const int numClasses = 10;

        // The CNN consumes 28x28x1 images; the MLP consumes the flattened 784-vector.
        var imageDim = useConvolution ? new[] { 28, 28, 1 } : new[] { 784 };
        IList<StreamConfiguration> streamConfigurations = new List<StreamConfiguration>
        {
            new StreamConfiguration(featureStreamName, imageSize),
            new StreamConfiguration(labelsStreamName, numClasses)
        };
        string modelFile = useConvolution ? "MNISTConvolution.model" : "MNISTMLP.model";

        // If a trained model already exists and retraining is not forced,
        // just validate the saved model against the test data and return.
        if (File.Exists(modelFile) && !forceRetrain)
        {
            var existingModelTestSource = MinibatchSource.TextFormatMinibatchSource(
                Path.Combine(ImageDataFolder, "Test_cntk_text.txt"), streamConfigurations);
            TestHelper.ValidateModelWithMinibatchSource(modelFile, existingModelTestSource,
                imageDim, numClasses, featureStreamName, labelsStreamName, classifierName, device);
            return;
        }

        // Build the network. Pixels are scaled by 1/256 (0.00390625) into [0, 1)
        // before being fed to either architecture.
        var input = CNTKLib.InputVariable(imageDim, DataType.Float, featureStreamName);
        var scaledInput = CNTKLib.ElementTimes(Constant.Scalar<float>(0.00390625f, device), input);

        Function classifierOutput;
        if (useConvolution)
        {
            classifierOutput = CreateConvolutionalNeuralNetwork(scaledInput, numClasses, device, classifierName);
        }
        else
        {
            // For the MLP we give the middle layer a fixed number of hidden units.
            int hiddenLayerDim = 200;
            classifierOutput = CreateMLPClassifier(device, numClasses, hiddenLayerDim, scaledInput, classifierName);
        }

        var labels = CNTKLib.InputVariable(new int[] { numClasses }, DataType.Float, labelsStreamName);
        var trainingLoss = CNTKLib.CrossEntropyWithSoftmax(new Variable(classifierOutput), labels, "lossFunction");
        var prediction = CNTKLib.ClassificationError(new Variable(classifierOutput), labels, "classificationError");

        // Prepare the training data reader.
        var minibatchSource = MinibatchSource.TextFormatMinibatchSource(
            Path.Combine(ImageDataFolder, "Train_cntk_text.txt"), streamConfigurations, MinibatchSource.InfinitelyRepeat);
        var featureStreamInfo = minibatchSource.StreamInfo(featureStreamName);
        var labelStreamInfo = minibatchSource.StreamInfo(labelsStreamName);

        // Per-sample learning rate for plain SGD.
        CNTK.TrainingParameterScheduleDouble learningRatePerSample = new CNTK.TrainingParameterScheduleDouble(
            0.003125, 1);
        IList<Learner> parameterLearners = new List<Learner>()
        {
            Learner.SGDLearner(classifierOutput.Parameters(), learningRatePerSample)
        };
        var trainer = Trainer.CreateTrainer(classifierOutput, trainingLoss, prediction, parameterLearners);

        // Training loop.
        const uint minibatchSize = 64;
        const int outputFrequencyInMinibatches = 20;
        int remainingEpochs = 5;
        int minibatchIndex = 0;
        while (remainingEpochs > 0)
        {
            var minibatchData = minibatchSource.GetNextMinibatch(minibatchSize, device);
            var arguments = new Dictionary<Variable, MinibatchData>
            {
                { input, minibatchData[featureStreamInfo] },
                { labels, minibatchData[labelStreamInfo] }
            };
            trainer.TrainMinibatch(arguments, device);
            TestHelper.PrintTrainingProgress(trainer, minibatchIndex++, outputFrequencyInMinibatches);

            // The MinibatchSource was created with MinibatchSource.InfinitelyRepeat,
            // so batching never ends on its own. Each time the source completes a
            // sweep (epoch) the last minibatch is flagged as end-of-sweep; we use
            // that flag to count down the number of epochs.
            if (TestHelper.MiniBatchDataIsSweepEnd(minibatchData.Values))
            {
                remainingEpochs--;
            }
        }

        // Persist the trained model, then validate it against the test data.
        classifierOutput.Save(modelFile);
        var newModelTestSource = MinibatchSource.TextFormatMinibatchSource(
            Path.Combine(ImageDataFolder, "Test_cntk_text.txt"), streamConfigurations, MinibatchSource.FullDataSweep);
        TestHelper.ValidateModelWithMinibatchSource(modelFile, newModelTestSource,
            imageDim, numClasses, featureStreamName, labelsStreamName, classifierName, device);
    }

    /// <summary>
    /// Builds a two-layer perceptron: a sigmoid hidden layer followed by a
    /// linear output layer named <paramref name="classifierName"/>.
    /// </summary>
    private static Function CreateMLPClassifier(DeviceDescriptor device, int numOutputClasses, int hiddenLayerDim,
        Function scaledInput, string classifierName)
    {
        var hiddenLayer = TestHelper.Dense(scaledInput, hiddenLayerDim, device, Activation.Sigmoid, "");
        return TestHelper.Dense(hiddenLayer, numOutputClasses, device, Activation.None, classifierName);
    }

    /// <summary>
    /// Create convolution neural network
    /// </summary>
    /// <param name="features">input feature variable</param>
    /// <param name="outDims">number of output classes</param>
    /// <param name="device">CPU or GPU device to run the model</param>
    /// <param name="classifierName">name of the classifier</param>
    /// <returns>the convolution neural network classifier</returns>
    static Function CreateConvolutionalNeuralNetwork(Variable features, int outDims, DeviceDescriptor device, string classifierName)
    {
        // First conv + max-pool stage: 28x28x1 -> 14x14x4
        Function firstPooling = ConvolutionWithMaxPooling(features, device,
            kernelWidth: 3, kernelHeight: 3, numInputChannels: 1, outFeatureMapCount: 4,
            hStride: 2, vStride: 2, poolingWindowWidth: 3, poolingWindowHeight: 3);

        // Second conv + max-pool stage: 14x14x4 -> 7x7x8
        Function secondPooling = ConvolutionWithMaxPooling(firstPooling, device,
            kernelWidth: 3, kernelHeight: 3, numInputChannels: 4, outFeatureMapCount: 8,
            hStride: 2, vStride: 2, poolingWindowWidth: 3, poolingWindowHeight: 3);

        // Linear read-out layer producing one score per class.
        return TestHelper.Dense(secondPooling, outDims, device, Activation.None, classifierName);
    }

    /// <summary>
    /// Builds one convolution layer with ReLU activation followed by max pooling.
    /// </summary>
    private static Function ConvolutionWithMaxPooling(Variable features, DeviceDescriptor device,
        int kernelWidth, int kernelHeight, int numInputChannels, int outFeatureMapCount,
        int hStride, int vStride, int poolingWindowWidth, int poolingWindowHeight)
    {
        // Hyper-parameter for Glorot-uniform weight initialization.
        double convWScale = 0.26;
        var kernelParameters = new Parameter(
            new int[] { kernelWidth, kernelHeight, numInputChannels, outFeatureMapCount }, DataType.Float,
            CNTKLib.GlorotUniformInitializer(convWScale, -1, 2), device);

        // Stride {1, 1, numInputChannels}: unit spatial strides, full stride over
        // the channel axis (CNTK convention for a dense channel reduction).
        Function convolved = CNTKLib.ReLU(
            CNTKLib.Convolution(kernelParameters, features, new int[] { 1, 1, numInputChannels } /* strides */));

        // Max-pool with auto-padding enabled (the bool[] { true } argument).
        return CNTKLib.Pooling(convolved, PoolingType.Max,
            new int[] { poolingWindowWidth, poolingWindowHeight }, new int[] { hStride, vStride }, new bool[] { true });
    }
}
}