|
|
@@ -276,6 +276,22 @@ def _Net_batch(self, blobs): |
|
|
padding])
|
|
|
yield padded_batch
|
|
|
|
|
|
+
|
|
|
+class _Net_IdNameWrapper:
|
|
|
+ """
|
|
|
+ A simple wrapper that allows the ids propery to be accessed as a dict
|
|
|
+ indexed by names. Used for top and bottom names
|
|
|
+ """
|
|
|
+ def __init__(self, net, func):
|
|
|
+ self.net, self.func = net, func
|
|
|
+
|
|
|
+ def __getitem__(self, name):
|
|
|
+ # Map the layer name to id
|
|
|
+ ids = self.func(self.net, list(self.net._layer_names).index(name))
|
|
|
+ # Map the blob id to name
|
|
|
+ id_to_name = list(self.net.blobs)
|
|
|
+ return [id_to_name[i] for i in ids]
|
|
|
+
|
|
|
# Attach methods to Net.
|
|
|
Net.blobs = _Net_blobs
|
|
|
Net.blob_loss_weights = _Net_blob_loss_weights
|
|
|
@@ -288,3 +304,5 @@ def _Net_batch(self, blobs): |
|
|
Net._batch = _Net_batch
|
|
|
Net.inputs = _Net_inputs
|
|
|
Net.outputs = _Net_outputs
|
|
|
+Net.top_names = property(lambda n: _Net_IdNameWrapper(n, Net._top_ids))
|
|
|
+Net.bottom_names = property(lambda n: _Net_IdNameWrapper(n, Net._bottom_ids))
|