|
|
@@ -44,6 +44,18 @@ def forward(self, bottom, top): |
|
|
def backward(self, top, propagate_down, bottom):
|
|
|
self.blobs[0].diff[0] = 1
|
|
|
|
|
|
+class PhaseLayer(caffe.Layer):
|
|
|
+ """A layer for checking attribute `phase`"""
|
|
|
+
|
|
|
+ def setup(self, bottom, top):
|
|
|
+ pass
|
|
|
+
|
|
|
+ def reshape(self, bootom, top):
|
|
|
+ top[0].reshape()
|
|
|
+
|
|
|
+ def forward(self, bottom, top):
|
|
|
+ top[0].data[()] = self.phase
|
|
|
+
|
|
|
def python_net_file():
|
|
|
with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
|
|
|
f.write("""name: 'pythonnet' force_backward: true
|
|
|
@@ -76,6 +88,14 @@ def parameter_net_file(): |
|
|
""")
|
|
|
return f.name
|
|
|
|
|
|
+def phase_net_file():
|
|
|
+ with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
|
|
|
+ f.write("""name: 'pythonnet' force_backward: true
|
|
|
+ layer { type: 'Python' name: 'layer' top: 'phase'
|
|
|
+ python_param { module: 'test_python_layer' layer: 'PhaseLayer' } }
|
|
|
+ """)
|
|
|
+ return f.name
|
|
|
+
|
|
|
|
|
|
@unittest.skipIf('Python' not in caffe.layer_type_list(),
|
|
|
'Caffe built without Python layer support')
|
|
|
@@ -140,3 +160,9 @@ def test_parameter(self): |
|
|
self.assertEqual(layer.blobs[0].data[0], 1)
|
|
|
|
|
|
os.remove(net_file)
|
|
|
+
|
|
|
+ def test_phase(self):
|
|
|
+ net_file = phase_net_file()
|
|
|
+ for phase in caffe.TRAIN, caffe.TEST:
|
|
|
+ net = caffe.Net(net_file, phase)
|
|
|
+ self.assertEqual(net.forward()['phase'], phase)
|