diff --git a/python/mxnet/visualization.py b/python/mxnet/visualization.py index 8c4cc3b920d7..47b712c872c3 100644 --- a/python/mxnet/visualization.py +++ b/python/mxnet/visualization.py @@ -134,20 +134,24 @@ def print_layer_summary(node, out_shape): pre_filter = pre_filter + int(shape[0]) cur_param = 0 if op == 'Convolution': - if ("no_bias" in node["attrs"]) and int(node["attrs"]["no_bias"]): - cur_param = pre_filter * int(node["attrs"]["num_filter"]) + if "no_bias" in node["attrs"] and node["attrs"]["no_bias"] == 'True': + num_group = int(node['attrs'].get('num_group', '1')) + cur_param = pre_filter * int(node["attrs"]["num_filter"]) \ + // num_group for k in _str2tuple(node["attrs"]["kernel"]): cur_param *= int(k) else: - cur_param = pre_filter * int(node["attrs"]["num_filter"]) + num_group = int(node['attrs'].get('num_group', '1')) + cur_param = pre_filter * int(node["attrs"]["num_filter"]) \ + // num_group for k in _str2tuple(node["attrs"]["kernel"]): cur_param *= int(k) cur_param += int(node["attrs"]["num_filter"]) elif op == 'FullyConnected': - if ("no_bias" in node["attrs"]) and int(node["attrs"]["no_bias"]): - cur_param = pre_filter * (int(node["attrs"]["num_hidden"])) + if "no_bias" in node["attrs"] and node["attrs"]["no_bias"] == 'True': + cur_param = pre_filter * int(node["attrs"]["num_hidden"]) else: - cur_param = (pre_filter+1) * (int(node["attrs"]["num_hidden"])) + cur_param = (pre_filter+1) * int(node["attrs"]["num_hidden"]) elif op == 'BatchNorm': key = node["name"] + "_output" if show_shape: