add layer_dict to the python interface #4347

Merged
merged 2 commits into from Feb 17, 2017
Jump to file or symbol
Failed to load files and symbols.
+18 −0
Split
View
@@ -43,6 +43,16 @@ def _Net_blob_loss_weights(self):
self._blob_loss_weights))
return self._blob_loss_weights_dict
+@property
+def _Net_layer_dict(self):
+ """
+ An OrderedDict (bottom to top, i.e., input to output) of network
+ layers indexed by name
+ """
+ if not hasattr(self, '_layer_dict'):
+ self._layer_dict = OrderedDict(zip(self._layer_names, self.layers))
+ return self._layer_dict
+
@property
def _Net_params(self):
@@ -311,6 +321,7 @@ def __getitem__(self, name):
# Attach methods to Net.
Net.blobs = _Net_blobs
Net.blob_loss_weights = _Net_blob_loss_weights
+Net.layer_dict = _Net_layer_dict
Net.params = _Net_params
Net.forward = _Net_forward
Net.backward = _Net_backward
@@ -59,6 +59,13 @@ def test_memory(self):
for bl in blobs:
total += bl.data.sum() + bl.diff.sum()
+ def test_layer_dict(self):
+ layer_dict = self.net.layer_dict
+ self.assertEqual(list(layer_dict.keys()), list(self.net._layer_names))
+ for i, name in enumerate(self.net._layer_names):
+ self.assertEqual(layer_dict[name].type,
+ self.net.layers[i].type)
+
def test_forward_backward(self):
self.net.forward()
self.net.backward()