
Commit 078885f

Author: Ray Phan (committed)

Added in preliminary version of training algorithm and engine

1 parent a756add   commit 078885f

File tree: 2 files changed, +475 -0 lines changed


engine/NeuralNet2.m

Lines changed: 373 additions & 0 deletions
@@ -0,0 +1,373 @@
classdef NeuralNet2 < handle

    properties (Access = public)
        LearningRate
        ActivationFunction
        RegularizationType
        RegularizationRate
        BatchSize
        Debug
    end

    properties (Access = private)
        inputSize
        hiddenSizes
        outputSize
        weights
    end

    methods
        % Class constructor
        function this = NeuralNet2(layerSizes)
            % default params
            this.LearningRate = 0.003;
            this.ActivationFunction = 'Tanh';
            this.RegularizationType = 'None';
            this.RegularizationRate = 0;
            this.BatchSize = 10;
            this.Debug = false;

            % network structure (fully-connected feed-forward)
            % Obtain input layer neuron size
            this.inputSize = layerSizes(1);

            % Obtain the hidden layer neuron sizes
            this.hiddenSizes = layerSizes(2:end-1);

            % Obtain the output layer neuron size
            this.outputSize = layerSizes(end);

            % Initialize the weight matrices relating the ith layer to the
            % (i+1)th layer (the extra row in each matrix holds the bias weights)
            this.weights = cell(1, numel(layerSizes)-1);
            for i = 1:numel(layerSizes)-1
                this.weights{i} = zeros(layerSizes(i)+1, layerSizes(i+1));
            end
        end

        % Configure network ranges from training data
        % (preliminary - the computed ranges are not used yet)
        function configure(this, X, Y)
            % check correct sizes
            [xrows, xcols] = size(X);
            [yrows, ycols] = size(Y);
            assert(xrows == yrows);
            assert(xcols == this.inputSize);
            assert(ycols == this.outputSize);

            % min/max of inputs/outputs
            inMin = min(X);
            inMax = max(X);
            outMin = min(Y);
            outMax = max(Y);
        end

        % Initialize neural network weights
        function init(this)
            % initialize with random weights
            for i = 1:numel(this.weights)
                num = numel(this.weights{i});
                this.weights{i}(:) = rand(num,1) - 0.5;   % [-0.5,0.5]
            end
        end

        % Perform training with Stochastic Gradient Descent
        function perf = train(this, X, Y, numIter)

            % Ensure correct sizes
            assert(size(X,1) == size(Y,1));

            % Ensure the regularization rate and batch size are valid
            assert(this.BatchSize >= 1);
            assert(this.RegularizationRate >= 0);

            % Check that a supported regularization type was specified
            regType = this.RegularizationType;
            assert(any(strcmpi(regType, {'l1', 'l2', 'none'})));

            % Initialize cost function array
            perf = zeros(1, numIter);

            % Initialize weights
            init(this);

            % Total number of examples
            N = size(X,1);

            % Total number of applicable layers
            L = numel(this.weights);

            % Get batch size
            B = this.BatchSize;

            % Safely catch if batch size is larger than total number
            % of examples
            if B > N
                B = N;
            end

            % Cell array to store the inputs into each neuron
            sNeuron = cell(1,L);

            % Cell array to store the outputs of each layer; the first
            % cell holds the (bias-augmented) input layer itself
            xNeuron = cell(1,L+1);

            % Cell array for storing the sensitivities
            delta = cell(1,L);

            % For L1 regularization
            % Method used: http://aclweb.org/anthology/P/P09/P09-1054.pdf
            if strcmpi(regType, 'l1')
                % This represents the total L1 penalty that each
                % weight could have received up to current point
                uk = 0;

                % Total penalty for each weight that was received up to
                % current point
                qk = cell(1,L);
                for ii = 1:L
                    qk{ii} = zeros(size(this.weights{ii}));
                end
            end

            % Get activation function
            fcn = getActivationFunction(this.ActivationFunction);

            skipFactor = floor(numIter/10);

            % For each iteration...
            for ii = 1:numIter
                % If the batch size is equal to the total number of examples
                % don't bother with random selection as this will be a full
                % batch gradient descent
                if N == B
                    ind = 1 : N;
                else
                    % Randomly select examples corresponding to the batch size
                    % if the batch size is not equal to the number of examples
                    ind = randperm(N, B);
                end

                % Select out the training example features and expected outputs
                IN = X(ind, :);
                OUT = Y(ind, :);

                % Initialize input layer
                xNeuron{1} = [IN ones(B,1)];

                %%% Perform forward propagation
                % Make sure you save the inputs and outputs into each neuron
                % at the hidden and output layers
                for jj = 1:L
                    % Compute inputs into next layer
                    sNeuron{jj} = xNeuron{jj} * this.weights{jj};

                    % Compute outputs of this layer
                    if jj == L
                        xNeuron{jj+1} = fcn(sNeuron{jj});
                    else
                        xNeuron{jj+1} = [fcn(sNeuron{jj}) ones(B,1)];
                    end
                end

                %%% Perform backpropagation
                % Get derivative of activation function
                dfcn = getDerivativeActivationFunction(this.ActivationFunction);

                % Compute sensitivities for output layer
                delta{end} = (xNeuron{end} - OUT) .* dfcn(sNeuron{end});

                % Compute the sensitivities for the rest of the layers
                for jj = L-1 : -1 : 1
                    delta{jj} = dfcn(sNeuron{jj}) .* ...
                        (delta{jj+1} * (this.weights{jj+1}(1:end-1,:)).');
                end

                %%% Compute weight updates
                alpha = this.LearningRate;
                lambda = this.RegularizationRate;
                for jj = 1 : L
                    % Obtain the outputs and sensitivities for each
                    % affected layer
                    XX = xNeuron{jj};
                    D = delta{jj};

                    % Calculate the batch weight update as the average of the
                    % per-example outer products between the layer outputs and
                    % the sensitivities
                    weight_update = (1/B)*sum(bsxfun(@times, permute(XX, [2 3 1]), ...
                        permute(D, [3 2 1])), 3);

                    % Apply L2 regularization if required
                    if strcmpi(regType, 'l2')
                        weight_update(1:end-1,:) = weight_update(1:end-1,:) + ...
                            (lambda/B)*this.weights{jj}(1:end-1,:);
                    end

                    % Compute the final update
                    this.weights{jj} = this.weights{jj} - alpha*weight_update;
                end

                % Apply L1 regularization if required
                if strcmpi(regType, 'l1')
                    % Step #1 - Accumulate total L1 penalty that each
                    % weight could have received up to this point
                    uk = uk + (alpha*lambda/B);

                    % Step #2
                    % Using the updated weights, now apply the penalties
                    for jj = 1 : L
                        % 2a - Save previous weights and penalties
                        % Make sure to remove bias terms
                        z = this.weights{jj}(1:end-1,:);
                        q = qk{jj}(1:end-1,:);

                        % 2b - Using the previous weights, find the weights
                        % that are positive and negative
                        w = z;
                        indwp = w > 0;
                        indwn = w < 0;

                        % 2c - Perform the update on each condition
                        % individually
                        w(indwp) = max(0, w(indwp) - (uk + q(indwp)));
                        w(indwn) = min(0, w(indwn) + (uk - q(indwn)));

                        % 2d - Update the actual penalties
                        qk{jj}(1:end-1,:) = q + (w - z);

                        % Don't forget to update the actual weights!
                        this.weights{jj}(1:end-1,:) = w;
                    end
                end

                % Compute cost at this iteration
                perf(ii) = (0.5/B)*sum(sum((xNeuron{end} - OUT).^2));

                % Add in regularization if necessary
                if strcmpi(regType, 'l1')
                    for jj = 1 : L
                        perf(ii) = perf(ii) + ...
                            (lambda/B)*sum(sum(abs(this.weights{jj}(1:end-1,:))));
                    end
                elseif strcmpi(regType, 'l2')
                    for jj = 1 : L
                        perf(ii) = perf(ii) + ...
                            (0.5*lambda/B)*sum(sum((this.weights{jj}(1:end-1,:)).^2));
                    end
                end

                % Debugging output
                if this.Debug
                    if mod(ii,skipFactor) == 1 || ii == numIter
                        fprintf('Iteration #%d - Cost: %4.6e\n', ii, perf(ii));
                    end
                end
            end
        end

        % Perform forward propagation through the network
        % Note that the bias weights occupy the last row of each weight matrix
        % Inputs are in a 2D matrix of N x M
        % N is the number of examples
        % M is the number of features / number of input neurons
        function OUT = sim(this, X)
            % Check if the total number of features matches the
            % total number of input neurons
            assert(size(X,2) == this.inputSize);

            % Get total number of examples
            N = size(X,1);

            %%% Begin algorithm
            % Start with first layer
            IN = X;

            % Get activation function
            fcn = getActivationFunction(this.ActivationFunction);

            % For each layer...
            for ii = 1:numel(this.weights)
                % Compute inputs into each neuron and corresponding
                % outputs
                OUT = fcn([IN ones(N,1)] * this.weights{ii});

                % Save for next iteration
                IN = OUT;
            end
        end
    end

end

function fcn = getActivationFunction(activation)
    switch lower(activation)
        case 'linear'
            fcn = @f_linear;
        case 'relu'
            fcn = @f_relu;
        case 'tanh'
            fcn = @f_tanh;
        case 'sigmoid'
            fcn = @f_sigmoid;
        otherwise
            error('Unknown activation function');
    end
end

function fcn = getDerivativeActivationFunction(activation)
    switch lower(activation)
        case 'linear'
            fcn = @fd_linear;
        case 'relu'
            fcn = @fd_relu;
        case 'tanh'
            fcn = @fd_tanh;
        case 'sigmoid'
            fcn = @fd_sigmoid;
        otherwise
            error('Unknown activation function');
    end
end

% activation functions and their derivatives
function y = f_linear(x)
    % See also: purelin
    y = x;
end

function y = fd_linear(x)
    % See also: dpurelin
    y = ones(size(x));
end

function y = f_relu(x)
    % See also: poslin
    y = max(x, 0);
end

function y = fd_relu(x)
    % See also: dposlin
    y = double(x >= 0);
end

function y = f_tanh(x)
    % See also: tansig
    %y = 2 ./ (1 + exp(-2*x)) - 1;
    y = tanh(x);
end

function y = fd_tanh(x)
    % See also: dtansig
    y = f_tanh(x);
    y = 1 - y.^2;
end

function y = f_sigmoid(x)
    % See also: logsig
    y = 1 ./ (1 + exp(-x));
end

function y = fd_sigmoid(x)
    % See also: dlogsig
    y = f_sigmoid(x);
    y = y .* (1 - y);
end
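
For context, here is a minimal usage sketch of the class added above. It is illustrative only and not part of this commit's diff: the 2-4-1 layer layout, the XOR-style toy data, and the iteration count are assumptions made for the example. train initializes the weights itself and returns the per-iteration cost, while sim performs a forward pass on new inputs.

    % Illustrative usage sketch (not part of this commit); data and sizes are made up
    X = [0 0; 0 1; 1 0; 1 1];          % 4 examples, 2 input features
    Y = [0; 1; 1; 0];                  % 1 expected output per example

    net = NeuralNet2([2 4 1]);         % 2 inputs, one hidden layer of 4, 1 output
    net.LearningRate = 0.1;            % override the 0.003 default
    net.ActivationFunction = 'Tanh';   % Linear, ReLU, Tanh or Sigmoid
    net.BatchSize = 4;                 % equals N here, so full-batch gradient descent
    net.Debug = true;                  % print the cost roughly every numIter/10 iterations

    perf = net.train(X, Y, 5000);      % perf holds the cost at every iteration
    pred = net.sim(X);                 % forward propagation with the trained weights

The regularization options follow the same pattern: setting RegularizationType to 'L1' or 'L2' (matched case-insensitively) together with a nonzero RegularizationRate enables the corresponding penalty inside train.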
