Skip to content

Commit

Permalink
Add phase support for draw net
Browse files Browse the repository at this point in the history
  • Loading branch information
cdoersch committed Jul 5, 2016
1 parent f28f5ae commit ed642a2
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 6 deletions.
29 changes: 24 additions & 5 deletions python/caffe/draw.py
Expand Up @@ -127,7 +127,7 @@ def choose_color_by_layertype(layertype):
return color


def get_pydot_graph(caffe_net, rankdir, label_edges=True):
def get_pydot_graph(caffe_net, rankdir, label_edges=True, phase=None):
"""Create a data structure which represents the `caffe_net`.
Parameters
Expand All @@ -137,6 +137,9 @@ def get_pydot_graph(caffe_net, rankdir, label_edges=True):
Direction of graph layout.
label_edges : boolean, optional
Label the edges (default is True).
phase : {caffe_pb2.Phase.TRAIN, caffe_pb2.Phase.TEST, None} optional
Include layers from this network phase. If None, include all layers.
(the default is None)
Returns
-------
Expand All @@ -148,6 +151,16 @@ def get_pydot_graph(caffe_net, rankdir, label_edges=True):
pydot_nodes = {}
pydot_edges = []
for layer in caffe_net.layer:
if phase is not None:
included = False
if len(layer.include) == 0:
included = True
for layer_phase in layer.include:
included = included or layer_phase.phase == phase
for layer_phase in layer.exclude:
included = included and not layer_phase.phase == phase
if not included:
continue
node_label = get_layer_label(layer, rankdir)
node_name = "%s_%s" % (layer.name, layer.type)
if (len(layer.bottom) == 1 and len(layer.top) == 1 and
Expand Down Expand Up @@ -186,7 +199,7 @@ def get_pydot_graph(caffe_net, rankdir, label_edges=True):
return pydot_graph


def draw_net(caffe_net, rankdir, ext='png'):
def draw_net(caffe_net, rankdir, ext='png', phase=None):
"""Draws a caffe net and returns the image string encoded using the given
extension.
Expand All @@ -195,16 +208,19 @@ def draw_net(caffe_net, rankdir, ext='png'):
caffe_net : a caffe.proto.caffe_pb2.NetParameter protocol buffer.
ext : string, optional
The image extension (the default is 'png').
phase : {caffe_pb2.Phase.TRAIN, caffe_pb2.Phase.TEST, None} optional
Include layers from this network phase. If None, include all layers.
(the default is None)
Returns
-------
string :
Postscript representation of the graph.
"""
return get_pydot_graph(caffe_net, rankdir).create(format=ext)
return get_pydot_graph(caffe_net, rankdir, phase=phase).create(format=ext)


def draw_net_to_file(caffe_net, filename, rankdir='LR'):
def draw_net_to_file(caffe_net, filename, rankdir='LR', phase=None):
"""Draws a caffe net, and saves it to file using the format given as the
file extension. Use '.raw' to output raw text that you can manually feed
to graphviz to draw graphs.
Expand All @@ -216,7 +232,10 @@ def draw_net_to_file(caffe_net, filename, rankdir='LR'):
The path to a file where the networks visualization will be stored.
rankdir : {'LR', 'TB', 'BT'}
Direction of graph layout.
phase : {caffe_pb2.Phase.TRAIN, caffe_pb2.Phase.TEST, None} optional
Include layers from this network phase. If None, include all layers.
(the default is None)
"""
ext = filename[filename.rfind('.')+1:]
with open(filename, 'wb') as fid:
fid.write(draw_net(caffe_net, rankdir, ext))
fid.write(draw_net(caffe_net, rankdir, ext, phase))
15 changes: 14 additions & 1 deletion python/draw_net.py
Expand Up @@ -28,6 +28,11 @@ def parse_args():
'http://www.graphviz.org/doc/info/'
'attrs.html#k:rankdir'),
default='LR')
parser.add_argument('--phase',
help=('Which network phase to draw: can be TRAIN, '
'TEST, or ALL. If ALL, then all layers are drawn '
'regardless of phase.'),
default="ALL")

args = parser.parse_args()
return args
Expand All @@ -38,7 +43,15 @@ def main():
net = caffe_pb2.NetParameter()
text_format.Merge(open(args.input_net_proto_file).read(), net)
print('Drawing net to %s' % args.output_image_file)
caffe.draw.draw_net_to_file(net, args.output_image_file, args.rankdir)
phase=None;
if args.phase == "TRAIN":
phase = caffe.TRAIN
elif args.phase == "TEST":
phase = caffe.TEST
elif args.phase != "ALL":
raise ValueError("Unknown phase: " + args.phase)
caffe.draw.draw_net_to_file(net, args.output_image_file, args.rankdir,
phase)


if __name__ == '__main__':
Expand Down

0 comments on commit ed642a2

Please sign in to comment.