Allow the python layer have attribute "phase" #3995

Merged
merged 2 commits into from May 4, 2016
Jump to file or symbol
Failed to load files and symbols.
+27 −0
Split
@@ -26,6 +26,7 @@ class PythonLayer : public Layer<Dtype> {
}
self_.attr("param_str") = bp::str(
this->layer_param_.python_param().param_str());
+ self_.attr("phase") = static_cast<int>(this->phase_);
self_.attr("setup")(bottom, top);
}
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
@@ -44,6 +44,18 @@ def forward(self, bottom, top):
def backward(self, top, propagate_down, bottom):
self.blobs[0].diff[0] = 1
+class PhaseLayer(caffe.Layer):
+ """A layer for checking attribute `phase`"""
+
+ def setup(self, bottom, top):
+ pass
+
+ def reshape(self, bootom, top):
+ top[0].reshape()
+
+ def forward(self, bottom, top):
+ top[0].data[()] = self.phase
+
def python_net_file():
with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
f.write("""name: 'pythonnet' force_backward: true
@@ -76,6 +88,14 @@ def parameter_net_file():
""")
return f.name
+def phase_net_file():
+ with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
+ f.write("""name: 'pythonnet' force_backward: true
+ layer { type: 'Python' name: 'layer' top: 'phase'
+ python_param { module: 'test_python_layer' layer: 'PhaseLayer' } }
+ """)
+ return f.name
+
@unittest.skipIf('Python' not in caffe.layer_type_list(),
'Caffe built without Python layer support')
@@ -140,3 +160,9 @@ def test_parameter(self):
self.assertEqual(layer.blobs[0].data[0], 1)
os.remove(net_file)
+
+ def test_phase(self):
+ net_file = phase_net_file()
+ for phase in caffe.TRAIN, caffe.TEST:
+ net = caffe.Net(net_file, phase)
+ self.assertEqual(net.forward()['phase'], phase)