/
Net.m
138 lines (130 loc) · 4.8 KB
/
Net.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
classdef Net < handle
  % Wrapper class of caffe::Net in matlab
  %
  % A Net is constructed in one of two ways:
  %   * from a single struct argument, which is treated as a net handle and
  %     wrapped directly; or
  %   * from any other argument list, which is forwarded unchanged to
  %     caffe.get_net and whose result becomes the constructed object.
  %
  % All interaction with the underlying net goes through the caffe_
  % entry point, which is called with a command string and the net handle.

  properties (Access = private)
    % handle struct passed to every caffe_ call made by this object
    hNet_self
    % struct returned by caffe_('net_get_attr', hNet_self)
    attributes
    % attribute fields
    %     hLayer_layers
    %     hBlob_blobs
    %     input_blob_indices
    %     output_blob_indices
    %     layer_names
    %     blob_names
  end
  properties (SetAccess = private)
    % array of caffe.Layer, one per entry of attributes.hLayer_layers
    layer_vec
    % array of caffe.Blob, one per entry of attributes.hBlob_blobs
    blob_vec
    % names of the input blobs (subset of blob_names)
    inputs
    % names of the output blobs (subset of blob_names)
    outputs
    % containers.Map: layer name -> 1-based index into layer_vec
    name2layer_index
    % containers.Map: blob name -> 1-based index into blob_vec
    name2blob_index
    % copy of attributes.layer_names, readable from outside the class
    layer_names
    % copy of attributes.blob_names, readable from outside the class
    blob_names
  end

  methods
    function self = Net(varargin)
      % Construct a Net from a handle struct or from caffe.get_net arguments.
      %
      % varargin: either exactly one struct (a net handle), or the
      %           arguments to forward to caffe.get_net.

      % decide whether to construct a net from model_file or handle
      if ~(nargin == 1 && isstruct(varargin{1}))
        % construct a net from model_file
        self = caffe.get_net(varargin{:});
        return
      end

      % construct a net from handle
      hNet_net = varargin{1};
      CHECK(is_valid_handle(hNet_net), 'invalid Net handle');

      % setup self handle and attributes
      self.hNet_self = hNet_net;
      self.attributes = caffe_('net_get_attr', self.hNet_self);

      % setup layer_vec
      self.layer_vec = caffe.Layer.empty();
      for n = 1:length(self.attributes.hLayer_layers)
        self.layer_vec(n) = caffe.Layer(self.attributes.hLayer_layers(n));
      end

      % setup blob_vec
      self.blob_vec = caffe.Blob.empty();
      for n = 1:length(self.attributes.hBlob_blobs)
        self.blob_vec(n) = caffe.Blob(self.attributes.hBlob_blobs(n));
      end

      % setup input and output blob and their names
      % note: add 1 to indices as matlab is 1-indexed while C++ is 0-indexed
      self.inputs = ...
        self.attributes.blob_names(self.attributes.input_blob_indices + 1);
      self.outputs = ...
        self.attributes.blob_names(self.attributes.output_blob_indices + 1);

      % create map objects to map from name to layers and blobs
      self.name2layer_index = containers.Map(self.attributes.layer_names, ...
        1:length(self.attributes.layer_names));
      self.name2blob_index = containers.Map(self.attributes.blob_names, ...
        1:length(self.attributes.blob_names));

      % expose layer_names and blob_names for public read access
      self.layer_names = self.attributes.layer_names;
      self.blob_names = self.attributes.blob_names;
    end

    function delete (self)
      % Destructor: ask caffe_ to delete the wrapped net.
      %
      % hNet_self is only assigned on the handle-construction path of the
      % constructor, so it is still empty for an object that left the
      % constructor early (caffe.get_net path) or failed the handle CHECK;
      % in that case there is nothing to release.
      if ~isempty(self.hNet_self)
        caffe_('delete_net', self.hNet_self);
      end
    end

    function layer = layers(self, layer_name)
      % Return the caffe.Layer stored under layer_name.
      %
      % layer_name: char array; must be one of self.layer_names.
      CHECK(ischar(layer_name), 'layer_name must be a string');
      % report an unknown name explicitly instead of relying on the
      % generic containers.Map missing-key error
      CHECK(isKey(self.name2layer_index, layer_name), ...
        sprintf('unknown layer name: %s', layer_name));
      layer = self.layer_vec(self.name2layer_index(layer_name));
    end

    function blob = blobs(self, blob_name)
      % Return the caffe.Blob stored under blob_name.
      %
      % blob_name: char array; must be one of self.blob_names.
      CHECK(ischar(blob_name), 'blob_name must be a string');
      % report an unknown name explicitly instead of relying on the
      % generic containers.Map missing-key error
      CHECK(isKey(self.name2blob_index, blob_name), ...
        sprintf('unknown blob name: %s', blob_name));
      blob = self.blob_vec(self.name2blob_index(blob_name));
    end

    function blob = params(self, layer_name, blob_index)
      % Return element blob_index of the params of the named layer.
      %
      % layer_name: char array; must be one of self.layer_names.
      % blob_index: scalar index into the layer's params.
      CHECK(ischar(layer_name), 'layer_name must be a string');
      CHECK(isscalar(blob_index), 'blob_index must be a scalar');
      % reuse layers() so the unknown-name check lives in one place
      layer = self.layers(layer_name);
      blob = layer.params(blob_index);
    end

    function forward_prefilled(self)
      % Issue 'net_forward' on the wrapped net without touching any blob
      % data from the matlab side.
      caffe_('net_forward', self.hNet_self);
    end

    function backward_prefilled(self)
      % Issue 'net_backward' on the wrapped net without touching any blob
      % diff from the matlab side.
      caffe_('net_backward', self.hNet_self);
    end

    function res = forward(self, input_data)
      % Set the input blobs' data, run forward, and collect output data.
      %
      % input_data: cell array with one entry per name in self.inputs,
      %             in the same order.
      % res:        column cell array with one entry per name in
      %             self.outputs, holding each output blob's get_data().
      CHECK(iscell(input_data), 'input_data must be a cell array');
      CHECK(length(input_data) == length(self.inputs), ...
        'input data cell length must match input blob number');

      % copy data to input blobs
      for n = 1:length(self.inputs)
        self.blobs(self.inputs{n}).set_data(input_data{n});
      end

      self.forward_prefilled();

      % retrieve data from output blobs
      res = cell(length(self.outputs), 1);
      for n = 1:length(self.outputs)
        res{n} = self.blobs(self.outputs{n}).get_data();
      end
    end

    function res = backward(self, output_diff)
      % Set the output blobs' diff, run backward, and collect input diffs.
      %
      % output_diff: cell array with one entry per name in self.outputs,
      %              in the same order.
      % res:         column cell array with one entry per name in
      %              self.inputs, holding each input blob's get_diff().
      CHECK(iscell(output_diff), 'output_diff must be a cell array');
      CHECK(length(output_diff) == length(self.outputs), ...
        'output diff cell length must match output blob number');

      % copy diff to output blobs
      for n = 1:length(self.outputs)
        self.blobs(self.outputs{n}).set_diff(output_diff{n});
      end

      self.backward_prefilled();

      % retrieve diff from input blobs
      res = cell(length(self.inputs), 1);
      for n = 1:length(self.inputs)
        res{n} = self.blobs(self.inputs{n}).get_diff();
      end
    end

    function copy_from(self, weights_file)
      % Issue 'net_copy_from' with weights_file on the wrapped net.
      %
      % weights_file: char array naming a file; its existence is verified
      %               with CHECK_FILE_EXIST before the call.
      CHECK(ischar(weights_file), 'weights_file must be a string');
      CHECK_FILE_EXIST(weights_file);
      caffe_('net_copy_from', self.hNet_self, weights_file);
    end

    function reshape(self)
      % Issue 'net_reshape' on the wrapped net.
      caffe_('net_reshape', self.hNet_self);
    end

    function save(self, weights_file)
      % Issue 'net_save' with weights_file on the wrapped net.
      %
      % weights_file: char array naming the destination file.
      CHECK(ischar(weights_file), 'weights_file must be a string');
      caffe_('net_save', self.hNet_self, weights_file);
    end
  end
end