|
|
@@ -25,11 +25,11 @@ def simple_net_file(num_output): |
|
|
bias_filler { type: 'constant' value: 2 } }
|
|
|
param { decay_mult: 1 } param { decay_mult: 0 }
|
|
|
}
|
|
|
- layer { type: 'InnerProduct' name: 'ip' bottom: 'conv' top: 'ip'
|
|
|
+ layer { type: 'InnerProduct' name: 'ip' bottom: 'conv' top: 'ip_blob'
|
|
|
inner_product_param { num_output: """ + str(num_output) + """
|
|
|
weight_filler { type: 'gaussian' std: 2.5 }
|
|
|
bias_filler { type: 'constant' value: -3 } } }
|
|
|
- layer { type: 'SoftmaxWithLoss' name: 'loss' bottom: 'ip' bottom: 'label'
|
|
|
+ layer { type: 'SoftmaxWithLoss' name: 'loss' bottom: 'ip_blob' bottom: 'label'
|
|
|
top: 'loss' }""")
|
|
|
f.close()
|
|
|
return f.name
|
|
|
@@ -71,6 +71,43 @@ def test_forward_backward(self): |
|
|
self.net.forward()
|
|
|
self.net.backward()
|
|
|
|
|
|
+ def test_forward_start_end(self):
|
|
|
+ conv_blob=self.net.blobs['conv'];
|
|
|
+ ip_blob=self.net.blobs['ip_blob'];
|
|
|
+ sample_data=np.random.uniform(size=conv_blob.data.shape);
|
|
|
+ sample_data=sample_data.astype(np.float32);
|
|
|
+ conv_blob.data[:]=sample_data;
|
|
|
+ forward_blob=self.net.forward(start='ip',end='ip');
|
|
|
+ self.assertIn('ip_blob',forward_blob);
|
|
|
+
|
|
|
+ manual_forward=[];
|
|
|
+ for i in range(0,conv_blob.data.shape[0]):
|
|
|
+ dot=np.dot(self.net.params['ip'][0].data,
|
|
|
+ conv_blob.data[i].reshape(-1));
|
|
|
+ manual_forward.append(dot+self.net.params['ip'][1].data);
|
|
|
+ manual_forward=np.array(manual_forward);
|
|
|
+
|
|
|
+ np.testing.assert_allclose(ip_blob.data,manual_forward,rtol=1e-3);
|
|
|
+
|
|
|
+ def test_backward_start_end(self):
|
|
|
+ conv_blob=self.net.blobs['conv'];
|
|
|
+ ip_blob=self.net.blobs['ip_blob'];
|
|
|
+ sample_data=np.random.uniform(size=ip_blob.data.shape)
|
|
|
+ sample_data=sample_data.astype(np.float32);
|
|
|
+ ip_blob.diff[:]=sample_data;
|
|
|
+ backward_blob=self.net.backward(start='ip',end='ip');
|
|
|
+ self.assertIn('conv',backward_blob);
|
|
|
+
|
|
|
+ manual_backward=[];
|
|
|
+ for i in range(0,conv_blob.data.shape[0]):
|
|
|
+ dot=np.dot(self.net.params['ip'][0].data.transpose(),
|
|
|
+ sample_data[i].reshape(-1));
|
|
|
+ manual_backward.append(dot);
|
|
|
+ manual_backward=np.array(manual_backward);
|
|
|
+ manual_backward=manual_backward.reshape(conv_blob.data.shape);
|
|
|
+
|
|
|
+ np.testing.assert_allclose(conv_blob.diff,manual_backward,rtol=1e-3);
|
|
|
+
|
|
|
def test_clear_param_diffs(self):
|
|
|
# Run a forward/backward step to have non-zero diffs
|
|
|
self.net.forward()
|
|
|
@@ -90,13 +127,13 @@ def test_top_bottom_names(self): |
|
|
self.assertEqual(self.net.top_names,
|
|
|
OrderedDict([('data', ['data', 'label']),
|
|
|
('conv', ['conv']),
|
|
|
- ('ip', ['ip']),
|
|
|
+ ('ip', ['ip_blob']),
|
|
|
('loss', ['loss'])]))
|
|
|
self.assertEqual(self.net.bottom_names,
|
|
|
OrderedDict([('data', []),
|
|
|
('conv', ['data']),
|
|
|
('ip', ['conv']),
|
|
|
- ('loss', ['ip', 'label'])]))
|
|
|
+ ('loss', ['ip_blob', 'label'])]))
|
|
|
|
|
|
def test_save_and_read(self):
|
|
|
f = tempfile.NamedTemporaryFile(mode='w+', delete=False)
|
|
|
|
0 comments on commit
2e33792