A Python script for at-a-glance net summary #3090
Merged
Jump to file or symbol
Failed to load files and symbols.
| @@ -0,0 +1,140 @@ | ||
| +#!/usr/bin/env python | ||
| + | ||
| +"""Net summarization tool. | ||
| + | ||
| +This tool summarizes the structure of a net in a concise but comprehensive | ||
| +tabular listing, taking a prototxt file as input. | ||
| + | ||
| +Use this tool to check at a glance that the computation you've specified is the | ||
| +computation you expect. | ||
| +""" | ||
| + | ||
| +from caffe.proto import caffe_pb2 | ||
| +from google import protobuf | ||
| +import re | ||
| +import argparse | ||
| + | ||
| +# ANSI codes for coloring blobs (used cyclically) | ||
| +COLORS = ['92', '93', '94', '95', '97', '96', '42', '43;30', '100', | ||
| + '444', '103;30', '107;30'] | ||
| +DISCONNECTED_COLOR = '41' | ||
| + | ||
| +def read_net(filename): | ||
| + net = caffe_pb2.NetParameter() | ||
| + with open(filename) as f: | ||
| + protobuf.text_format.Parse(f.read(), net) | ||
| + return net | ||
| + | ||
| +def format_param(param): | ||
| + out = [] | ||
| + if len(param.name) > 0: | ||
| + out.append(param.name) | ||
| + if param.lr_mult != 1: | ||
| + out.append('x{}'.format(param.lr_mult)) | ||
| + if param.decay_mult != 1: | ||
| + out.append('Dx{}'.format(param.decay_mult)) | ||
| + return ' '.join(out) | ||
| + | ||
| +def printed_len(s): | ||
| + return len(re.sub(r'\033\[[\d;]+m', '', s)) | ||
| + | ||
| +def print_table(table, max_width): | ||
| + """Print a simple nicely-aligned table. | ||
| + | ||
| + table must be a list of (equal-length) lists. Columns are space-separated, | ||
| + and as narrow as possible, but no wider than max_width. Text may overflow | ||
| + columns; note that unlike string.format, this will not affect subsequent | ||
| + columns, if possible.""" | ||
| + | ||
| + max_widths = [max_width] * len(table[0]) | ||
| + column_widths = [max(printed_len(row[j]) + 1 for row in table) | ||
| + for j in range(len(table[0]))] | ||
| + column_widths = [min(w, max_w) for w, max_w in zip(column_widths, max_widths)] | ||
| + | ||
| + for row in table: | ||
| + row_str = '' | ||
| + right_col = 0 | ||
| + for cell, width in zip(row, column_widths): | ||
| + right_col += width | ||
| + row_str += cell + ' ' | ||
| + row_str += ' ' * max(right_col - printed_len(row_str), 0) | ||
| + print row_str | ||
| + | ||
| +def summarize_net(net): | ||
| + disconnected_tops = set() | ||
| + for lr in net.layer: | ||
| + disconnected_tops |= set(lr.top) | ||
| + disconnected_tops -= set(lr.bottom) | ||
| + | ||
| + table = [] | ||
| + colors = {} | ||
| + for lr in net.layer: | ||
| + tops = [] | ||
| + for ind, top in enumerate(lr.top): | ||
| + color = colors.setdefault(top, COLORS[len(colors) % len(COLORS)]) | ||
| + if top in disconnected_tops: | ||
| + top = '\033[1;4m' + top | ||
| + if len(lr.loss_weight) > 0: | ||
| + top = '{} * {}'.format(lr.loss_weight[ind], top) | ||
| + tops.append('\033[{}m{}\033[0m'.format(color, top)) | ||
| + top_str = ', '.join(tops) | ||
| + | ||
| + bottoms = [] | ||
| + for bottom in lr.bottom: | ||
| + color = colors.get(bottom, DISCONNECTED_COLOR) | ||
| + bottoms.append('\033[{}m{}\033[0m'.format(color, bottom)) | ||
| + bottom_str = ', '.join(bottoms) | ||
| + | ||
| + if lr.type == 'Python': | ||
| + type_str = lr.python_param.module + '.' + lr.python_param.layer | ||
| + else: | ||
| + type_str = lr.type | ||
| + | ||
| + # Summarize conv/pool parameters. | ||
| + # TODO support rectangular/ND parameters | ||
| + conv_param = lr.convolution_param | ||
| + if (lr.type in ['Convolution', 'Deconvolution'] | ||
| + and len(conv_param.kernel_size) == 1): | ||
| + arg_str = str(conv_param.kernel_size[0]) | ||
| + if len(conv_param.stride) > 0 and conv_param.stride[0] != 1: | ||
| + arg_str += '/' + str(conv_param.stride[0]) | ||
| + if len(conv_param.pad) > 0 and conv_param.pad[0] != 0: | ||
| + arg_str += '+' + str(conv_param.pad[0]) | ||
| + arg_str += ' ' + str(conv_param.num_output) | ||
| + if conv_param.group != 1: | ||
| + arg_str += '/' + str(conv_param.group) | ||
| + elif lr.type == 'Pooling': | ||
| + arg_str = str(lr.pooling_param.kernel_size) | ||
| + if lr.pooling_param.stride != 1: | ||
| + arg_str += '/' + str(lr.pooling_param.stride) | ||
| + if lr.pooling_param.pad != 0: | ||
| + arg_str += '+' + str(lr.pooling_param.pad) | ||
| + else: | ||
| + arg_str = '' | ||
| + | ||
| + if len(lr.param) > 0: | ||
| + param_strs = map(format_param, lr.param) | ||
| + if max(map(len, param_strs)) > 0: | ||
| + param_str = '({})'.format(', '.join(param_strs)) | ||
| + else: | ||
| + param_str = '' | ||
| + else: | ||
| + param_str = '' | ||
| + | ||
| + table.append([lr.name, type_str, param_str, bottom_str, '->', top_str, | ||
| + arg_str]) | ||
| + return table | ||
| + | ||
| +def main(): | ||
| + parser = argparse.ArgumentParser(description="Print a concise summary of net computation.") | ||
| + parser.add_argument('filename', help='net prototxt file to summarize') | ||
| + parser.add_argument('-w', '--max-width', help='maximum field width', | ||
| + type=int, default=30) | ||
| + args = parser.parse_args() | ||
| + | ||
| + net = read_net(args.filename) | ||
| + table = summarize_net(net) | ||
| + print_table(table, max_width=args.max_width) | ||
| + | ||
| +if __name__ == '__main__': | ||
| + main() |