Handle Python layer exceptions correctly #2462

Merged
merged 3 commits into from Aug 6, 2015
@@ -18,44 +18,23 @@ class PythonLayer : public Layer<Dtype> {
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
- try {
- self_.attr("setup")(bottom, top);
- } catch (bp::error_already_set) {
- PyErr_Print();
- throw;
- }
+ self_.attr("setup")(bottom, top);
}
-
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
- try {
- self_.attr("reshape")(bottom, top);
- } catch (bp::error_already_set) {
- PyErr_Print();
- throw;
- }
+ self_.attr("reshape")(bottom, top);
}
virtual inline const char* type() const { return "Python"; }
protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
- try {
- self_.attr("forward")(bottom, top);
- } catch (bp::error_already_set) {
- PyErr_Print();
- throw;
- }
+ self_.attr("forward")(bottom, top);
}
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
- try {
- self_.attr("backward")(top, propagate_down, bottom);
- } catch (bp::error_already_set) {
- PyErr_Print();
- throw;
- }
+ self_.attr("backward")(top, propagate_down, bottom);
}
private:
@@ -21,6 +21,13 @@ def backward(self, top, propagate_down, bottom):
bottom[0].diff[...] = 10 * top[0].diff
+class ExceptionLayer(caffe.Layer):
+ """A layer for checking exceptions from Python"""
+
+ def setup(self, bottom, top):
+ raise RuntimeError
+
+
def python_net_file():
with tempfile.NamedTemporaryFile(delete=False) as f:
f.write("""name: 'pythonnet' force_backward: true
@@ -34,6 +41,16 @@ def python_net_file():
return f.name
+def exception_net_file():
+ with tempfile.NamedTemporaryFile(delete=False) as f:
+ f.write("""name: 'pythonnet' force_backward: true
+ input: 'data' input_shape { dim: 10 dim: 9 dim: 8 }
+ layer { type: 'Python' name: 'layer' bottom: 'data' top: 'top'
+ python_param { module: 'test_python_layer' layer: 'ExceptionLayer' } }
+ """)
+ return f.name
+
+
class TestPythonLayer(unittest.TestCase):
def setUp(self):
net_file = python_net_file()
@@ -61,3 +78,8 @@ def test_reshape(self):
for blob in self.net.blobs.itervalues():
for d in blob.data.shape:
self.assertEqual(s, d)
+
+ def test_exception(self):
+ net_file = exception_net_file()
+ self.assertRaises(RuntimeError, caffe.Net, net_file, caffe.TEST)
+ os.remove(net_file)
View
@@ -8,6 +8,11 @@
#include "boost/algorithm/string.hpp"
#include "caffe/caffe.hpp"
+#ifdef WITH_PYTHON_LAYER
+#include "boost/python.hpp"
+namespace bp = boost::python;
+#endif
+
using caffe::Blob;
using caffe::Caffe;
using caffe::Net;
@@ -304,7 +309,16 @@ int main(int argc, char** argv) {
// Run tool or show usage.
caffe::GlobalInit(&argc, &argv);
if (argc == 2) {
- return GetBrewFunction(caffe::string(argv[1]))();
+#ifdef WITH_PYTHON_LAYER
+ try {
+#endif
+ return GetBrewFunction(caffe::string(argv[1]))();
+#ifdef WITH_PYTHON_LAYER
+ } catch (bp::error_already_set) {
+ PyErr_Print();
+ return 1;
+ }
+#endif
} else {
gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/caffe");
}