|
|
@@ -28,6 +28,21 @@ class ExceptionLayer(caffe.Layer): |
|
|
def setup(self, bottom, top):
|
|
|
raise RuntimeError
|
|
|
|
|
|
+class ParameterLayer(caffe.Layer):
|
|
|
+ """A layer that just multiplies by ten"""
|
|
|
+
|
|
|
+ def setup(self, bottom, top):
|
|
|
+ self.blobs.add_blob(1)
|
|
|
+ self.blobs[0].data[0] = 0
|
|
|
+
|
|
|
+ def reshape(self, bottom, top):
|
|
|
+ top[0].reshape(*bottom[0].data.shape)
|
|
|
+
|
|
|
+ def forward(self, bottom, top):
|
|
|
+ pass
|
|
|
+
|
|
|
+ def backward(self, top, propagate_down, bottom):
|
|
|
+ self.blobs[0].diff[0] = 1
|
|
|
|
|
|
def python_net_file():
|
|
|
with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
|
|
|
@@ -52,6 +67,16 @@ def exception_net_file(): |
|
|
return f.name
|
|
|
|
|
|
|
|
|
+def parameter_net_file():
|
|
|
+ with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
|
|
|
+ f.write("""name: 'pythonnet' force_backward: true
|
|
|
+ input: 'data' input_shape { dim: 10 dim: 9 dim: 8 }
|
|
|
+ layer { type: 'Python' name: 'layer' bottom: 'data' top: 'top'
|
|
|
+ python_param { module: 'test_python_layer' layer: 'ParameterLayer' } }
|
|
|
+ """)
|
|
|
+ return f.name
|
|
|
+
|
|
|
+
|
|
|
class TestPythonLayer(unittest.TestCase):
|
|
|
def setUp(self):
|
|
|
net_file = python_net_file()
|
|
|
@@ -84,3 +109,32 @@ def test_exception(self): |
|
|
net_file = exception_net_file()
|
|
|
self.assertRaises(RuntimeError, caffe.Net, net_file, caffe.TEST)
|
|
|
os.remove(net_file)
|
|
|
+
|
|
|
+ def test_parameter(self):
|
|
|
+ net_file = parameter_net_file()
|
|
|
+ net = caffe.Net(net_file, caffe.TRAIN)
|
|
|
+ # Test forward and backward
|
|
|
+ net.forward()
|
|
|
+ net.backward()
|
|
|
+ layer = net.layers[list(net._layer_names).index('layer')]
|
|
|
+ self.assertEqual(layer.blobs[0].data[0], 0)
|
|
|
+ self.assertEqual(layer.blobs[0].diff[0], 1)
|
|
|
+ layer.blobs[0].data[0] += layer.blobs[0].diff[0]
|
|
|
+ self.assertEqual(layer.blobs[0].data[0], 1)
|
|
|
+
|
|
|
+ # Test saving and loading
|
|
|
+ h, caffemodel_file = tempfile.mkstemp()
|
|
|
+ net.save(caffemodel_file)
|
|
|
+ layer.blobs[0].data[0] = -1
|
|
|
+ self.assertEqual(layer.blobs[0].data[0], -1)
|
|
|
+ net.copy_from(caffemodel_file)
|
|
|
+ self.assertEqual(layer.blobs[0].data[0], 1)
|
|
|
+ os.remove(caffemodel_file)
|
|
|
+
|
|
|
+ # Test weight sharing
|
|
|
+ net2 = caffe.Net(net_file, caffe.TRAIN)
|
|
|
+ net2.share_with(net)
|
|
|
+ layer = net.layers[list(net2._layer_names).index('layer')]
|
|
|
+ self.assertEqual(layer.blobs[0].data[0], 1)
|
|
|
+
|
|
|
+ os.remove(net_file)
|