Exposing layer top and bottom names to python #2865

Merged
merged 1 commit into from Jan 8, 2016
Jump to file or symbol
Failed to load files and symbols.
+34 −0
Split
View
@@ -149,6 +149,18 @@ class Net {
inline const vector<vector<Blob<Dtype>*> >& top_vecs() const {
return top_vecs_;
}
+ /// @brief returns the ids of the top blobs of layer i
+ inline const vector<int> & top_ids(int i) const {
+ CHECK_GE(i, 0) << "Invalid layer id";
+ CHECK_LT(i, top_id_vecs_.size()) << "Invalid layer id";
+ return top_id_vecs_[i];
+ }
+ /// @brief returns the ids of the bottom blobs of layer i
+ inline const vector<int> & bottom_ids(int i) const {
+ CHECK_GE(i, 0) << "Invalid layer id";
+ CHECK_LT(i, bottom_id_vecs_.size()) << "Invalid layer id";
+ return bottom_id_vecs_[i];
+ }
inline const vector<vector<bool> >& bottom_need_backward() const {
return bottom_need_backward_;
}
View
@@ -232,6 +232,10 @@ BOOST_PYTHON_MODULE(_caffe) {
.def("share_with", &Net<Dtype>::ShareTrainedLayersWith)
.add_property("_blob_loss_weights", bp::make_function(
&Net<Dtype>::blob_loss_weights, bp::return_internal_reference<>()))
+ .def("_bottom_ids", bp::make_function(&Net<Dtype>::bottom_ids,
+ bp::return_value_policy<bp::copy_const_reference>()))
+ .def("_top_ids", bp::make_function(&Net<Dtype>::top_ids,
+ bp::return_value_policy<bp::copy_const_reference>()))
.add_property("_blobs", bp::make_function(&Net<Dtype>::blobs,
bp::return_internal_reference<>()))
.add_property("layers", bp::make_function(&Net<Dtype>::layers,
View
@@ -276,6 +276,22 @@ def _Net_batch(self, blobs):
padding])
yield padded_batch
+
+class _Net_IdNameWrapper:
+ """
+ A simple wrapper that allows the ids propery to be accessed as a dict
+ indexed by names. Used for top and bottom names
+ """
+ def __init__(self, net, func):
+ self.net, self.func = net, func
+
+ def __getitem__(self, name):
+ # Map the layer name to id
+ ids = self.func(self.net, list(self.net._layer_names).index(name))
+ # Map the blob id to name
+ id_to_name = list(self.net.blobs)
+ return [id_to_name[i] for i in ids]
+
# Attach methods to Net.
Net.blobs = _Net_blobs
Net.blob_loss_weights = _Net_blob_loss_weights
@@ -288,3 +304,5 @@ def _Net_batch(self, blobs):
Net._batch = _Net_batch
Net.inputs = _Net_inputs
Net.outputs = _Net_outputs
+Net.top_names = property(lambda n: _Net_IdNameWrapper(n, Net._top_ids))
+Net.bottom_names = property(lambda n: _Net_IdNameWrapper(n, Net._bottom_ids))