% mnist_example.m
% Train a deep belief network (DBN) on MNIST, report the test classification
% error, and plot training examples, misclassified test cases, Gibbs samples
% and first-layer weights.
%% load MNIST and train the deep belief network
% mnist_uint8 is expected to define train_x, train_y, test_x and test_y
% (NOTE(review): presumably uint8 pixel rows and one-hot label rows --
% confirm against the .mat file). Pixels are rescaled to doubles in [0,1].
load mnist_uint8;
train_x = double(train_x) / 255;
train_y = double(train_y);
% Only the test set is moved to the GPU (gpuArray needs the Parallel
% Computing Toolbox and a supported GPU); training data stays on the host.
test_x = gpuArray(double(test_x) / 255);
test_y = gpuArray(double(test_y));

% Optional (disabled) table derived from train_y, handed to DBN as opts.part.
% NOTE(review): its exact meaning is defined inside DBN -- confirm there.
% part = zeros(10, 2);
% for i = 1:10
% [~, col, ~] = find(train_y);
% col = find(col == i);
% part(i, :) = [col(1) col(end)];
% end

% Reset the global RNG to MATLAB's default seed so runs are reproducible;
% done before DBN construction, presumably because weights start random.
rng('default');

% model parameters
sizes = [500 500 2000];   % presumably hidden-layer sizes, bottom to top
% Build opts from scratch: assigning fields one at a time onto a variable
% left over in the workspace would silently keep stale fields (for example
% opts.part from an earlier run) and pass them on to DBN.
opts = struct( ...
    'numepochs', 10, ...
    'batchsize', 10, ...
    'momentum',  0, ...
    'alpha',     0.1, ...       presumably the learning rate
    'decay',     0.00001, ...   presumably weight decay
    'k',         1);          % presumably the number of CD-k Gibbs steps
% opts.part = part;

dbn = DBN(train_x, train_y, sizes, opts);
tic
% NOTE(review): the return value is discarded, so this relies on DBN being a
% handle class that train() updates in place -- confirm in DBN.
train(dbn, train_x, train_y);
toc
%% compute most probable class label given test data
% NOTE(review): dbn.predict is assumed to return one row of class scores per
% test case -- confirm in DBN. A test case counts as correct only when the
% one-hot row built from its arg-max class equals its row of test_y.
probs = dbn.predict(test_x, test_y);
[~, I] = max(probs, [], 2);
% Take the class count from the labels instead of hard-coding 10, so pred
% always has the same width as test_y in the comparison below.
pred = bsxfun(@eq, I, 1:size(test_y, 2));
mis = find(~all(pred == test_y, 2));   % row indices of misclassified test cases
err = numel(mis) / size(test_y, 1);
fprintf('Classification error is %3.2f%%\n',err*100);
%% plot MNIST examples
% Draw a 4x10 grid of training digits. Column c uses the training rows whose
% entry in column c of train_y is nonzero; cell (r, c) shows the
% ((r-1)*10 + c)-th such row.
figure('Color','black');
[rowIdx, classIdx, ~] = find(train_y);
% Collect each class's training-row indices once rather than once per cell.
examplesOf = cell(1, 10);
for c = 1:10
    examplesOf{c} = rowIdx(classIdx == c);
end
for r = 1:4
    for c = 1:10
        cellNo = 10*(r - 1) + c;
        members = examplesOf{c};
        subplot(4, 10, cellNo);
        imshow(reshape(train_x(members(cellNo), :), 28, 28)');
    end
end
%% plot some misclassified test cases
% Show up to ten misclassified test digits (every 10th entry of mis among
% the first 100), titled with predicted and actual labels. Label column k
% is printed as digit k-1.
figure('Color','black');
% Clamp to the number of misclassifications actually found: the unclamped
% mis(1:10:100) errored whenever fewer than 91 test cases were misclassified.
idx = mis(1:10:min(end, 100));
for i = 1:numel(idx)
    subplot(2,5,i), imshow(reshape(test_x(idx(i),:), 28, 28)');
    [~, predicted] = max(probs(idx(i),:));
    [~, actual] = max(test_y(idx(i),:));
    t = title(sprintf('Predicted %d\nActual %d', predicted - 1, actual - 1), 'Color', 'white');
    % Anchor the two-line title at the top-left corner of the axes.
    set(t, 'horizontalAlignment', 'left');
    set(t, 'units', 'normalized');
    set(t, 'position', [0 1 0]);
end
%% plot samples as iterations of gibbs sampling increases
% Grid with one column per first argument 1..10 and one row per entry of
% gibbSteps. NOTE(review): dbn.generate2(label, 10, steps) is assumed to
% return a 784-element image after `steps` Gibbs iterations -- confirm in DBN.
figure('Color','black');
gibbSteps = [1, 10, 100, 1000];
nRows = numel(gibbSteps);
for label = 1:10
    for r = 1:nRows
        subplot(nRows, 10, 10*(r - 1) + label);
        sample = dbn.generate2(label, 10, gibbSteps(r));
        imshow(reshape(sample, 28, 28)');
    end
end
%% visualize the weights of the first layer
% Render rows 1..100 of the first RBM's weight matrix as 28x28 images in a
% 10x10 grid, with the display range fixed to [-1, 1].
figure('Color','black');
W1 = dbn.rbm(1).W;
for unit = 1:100
    subplot(10, 10, unit);
    imshow(reshape(W1(unit, :), 28, 28)', [-1, 1]);
end
%% save a sequence of samples generated by each step of gibbs sampling
% Call imageseq once for each first argument 1..10, with 10 and 200 as the
% remaining arguments. NOTE(review): per the section title this presumably
% writes one Gibbs-sampling image sequence per class -- confirm in imageseq.
for label = 1:10
    imageseq(dbn, label, 10, 200);
end