|
|
@@ -18,7 +18,7 @@ class -- assign to its attributes directly to name layers, and call |
|
|
are not guaranteed to be forward-compatible.
|
|
|
"""
|
|
|
|
|
|
-from collections import OrderedDict
|
|
|
+from collections import OrderedDict, Counter
|
|
|
|
|
|
from .proto import caffe_pb2
|
|
|
from google import protobuf
|
|
|
@@ -44,10 +44,8 @@ def to_proto(*tops): |
|
|
"""Generate a NetParameter that contains all layers needed to compute
|
|
|
all arguments."""
|
|
|
|
|
|
- if not isinstance(tops, tuple):
|
|
|
- tops = (tops,)
|
|
|
layers = OrderedDict()
|
|
|
- autonames = {}
|
|
|
+ autonames = Counter()
|
|
|
for top in tops:
|
|
|
top.fn._to_proto(layers, {}, autonames)
|
|
|
net = caffe_pb2.NetParameter()
|
|
|
@@ -89,6 +87,9 @@ def to_proto(self): |
|
|
|
|
|
return to_proto(self)
|
|
|
|
|
|
+ def _to_proto(self, layers, names, autonames):
|
|
|
+ return self.fn._to_proto(layers, names, autonames)
|
|
|
+
|
|
|
|
|
|
class Function(object):
|
|
|
"""A Function specifies a layer, its parameters, and its inputs (which
|
|
|
@@ -107,19 +108,26 @@ def __init__(self, type_name, inputs, params): |
|
|
del self.params['in_place']
|
|
|
self.tops = tuple(Top(self, n) for n in range(self.ntop))
|
|
|
|
|
|
- def _get_name(self, top, names, autonames):
|
|
|
+ def _get_name(self, names, autonames):
|
|
|
+ if self not in names and self.ntop > 0:
|
|
|
+ names[self] = self._get_top_name(self.tops[0], names, autonames)
|
|
|
+ elif self not in names:
|
|
|
+ autonames[self.type_name] += 1
|
|
|
+ names[self] = self.type_name + str(autonames[self.type_name])
|
|
|
+ return names[self]
|
|
|
+
|
|
|
+ def _get_top_name(self, top, names, autonames):
|
|
|
if top not in names:
|
|
|
- n = autonames.setdefault(top.fn.type_name, 1)
|
|
|
autonames[top.fn.type_name] += 1
|
|
|
- names[top] = top.fn.type_name + str(n)
|
|
|
+ names[top] = top.fn.type_name + str(autonames[top.fn.type_name])
|
|
|
return names[top]
|
|
|
|
|
|
def _to_proto(self, layers, names, autonames):
|
|
|
if self in layers:
|
|
|
return
|
|
|
bottom_names = []
|
|
|
for inp in self.inputs:
|
|
|
- inp.fn._to_proto(layers, names, autonames)
|
|
|
+ inp._to_proto(layers, names, autonames)
|
|
|
bottom_names.append(layers[inp.fn].top[inp.n])
|
|
|
layer = caffe_pb2.LayerParameter()
|
|
|
layer.type = self.type_name
|
|
|
@@ -129,8 +137,8 @@ def _to_proto(self, layers, names, autonames): |
|
|
layer.top.extend(layer.bottom)
|
|
|
else:
|
|
|
for top in self.tops:
|
|
|
- layer.top.append(self._get_name(top, names, autonames))
|
|
|
- layer.name = self._get_name(self.tops[0], names, autonames)
|
|
|
+ layer.top.append(self._get_top_name(top, names, autonames))
|
|
|
+ layer.name = self._get_name(names, autonames)
|
|
|
|
|
|
for k, v in six.iteritems(self.params):
|
|
|
# special case to handle generic *params
|
|
|
@@ -163,10 +171,10 @@ def __getattr__(self, name): |
|
|
|
|
|
def to_proto(self):
|
|
|
names = {v: k for k, v in six.iteritems(self.tops)}
|
|
|
- autonames = {}
|
|
|
+ autonames = Counter()
|
|
|
layers = OrderedDict()
|
|
|
for name, top in six.iteritems(self.tops):
|
|
|
- top.fn._to_proto(layers, names, autonames)
|
|
|
+ top._to_proto(layers, names, autonames)
|
|
|
net = caffe_pb2.NetParameter()
|
|
|
net.layer.extend(layers.values())
|
|
|
return net
|
|
|
@@ -180,7 +188,9 @@ class Layers(object): |
|
|
def __getattr__(self, name):
|
|
|
def layer_fn(*args, **kwargs):
|
|
|
fn = Function(name, args, kwargs)
|
|
|
- if fn.ntop == 1:
|
|
|
+ if fn.ntop == 0:
|
|
|
+ return fn
|
|
|
+ elif fn.ntop == 1:
|
|
|
return fn.tops[0]
|
|
|
else:
|
|
|
return fn.tops
|
|
|
|