# Torchmeter

An `all-in-one` tool for `PyTorch` model analysis, measuring:
- Params
- FLOPs / MACs (aka. MACC or MADD)
- Memory cost
- Inference time
- Throughput

Project: https://github.com/Ahzyuan/torchmeter

In [1]:
# install 
# %pip install torchmeter

In [2]:
import torch
from rich import print
from torchvision import models
from torchmeter import Meter, get_config

# 1. Model preparation and Meter initialization
model = models.vgg19_bn()
metered_model = Meter(model)
cfg = get_config()

# 2. Zero-intrusion proxy: `features` is reached through the Meter wrapper
metered_model.features.requires_grad_(False)

# 3. Input handling and inference (demonstrates device synchronization)
# NOTE(review): `input` shadows the Python builtin of the same name.
input = torch.randn(1, 3, 224, 224)
if torch.cuda.is_available():
    metered_model.to('cuda')
# The input tensor is never moved explicitly; presumably Meter syncs it to
# the model's device -- confirm against the torchmeter docs.
output = metered_model(input)  # standard forward pass

In [3]:
# 4. Model structure analysis
# --------------------------------------------
## Enable folding of repeated blocks
print("="*10, " enable smart folding of repeat blocks", "="*10)
metered_model.tree_fold_repeat = True
print(metered_model.structure)

## Disable folding of repeated blocks
print("="*10, " disable smart folding of repeat blocks", "="*10)
metered_model.tree_fold_repeat = False
print(metered_model.structure)

In [4]:
# 5. Performance analysis (demonstrates full-stack analysis capabilities)
# NOTE(review): render_interval = 0 presumably turns off the delay between
# rendering steps -- confirm against the torchmeter config docs.
cfg.render_interval = 0
# --------------------------------------------
# 5.1 Parameter analysis (demonstrates parameter measurement)
print("="*10, " Parameter Analysis ", "="*10)
print(metered_model.param)  # overall parameter statistics
tb, data = metered_model.profile('param', no_tree=True)  # per-layer parameter distribution

In [5]:
# 5.2 Computation analysis (demonstrates FLOPs/MACs)
print("="*10, " Computational Profiling ", "="*10)
print(metered_model.cal)
# Same profile() call shape as the other statistics, with 'cal' as the statistic name
tb, data = metered_model.profile('cal', no_tree=True)

In [6]:
# 5.3 Memory analysis (demonstrates memory diagnostics)
print("="*10, " Memory Diagnostics ", "="*10)
print(metered_model.mem)
# Same profile() call shape as the other statistics, with 'mem' as the statistic name
tb, data = metered_model.profile('mem', no_tree=True)

In [7]:
# 5.4 Performance benchmarking (demonstrates latency and throughput)
print("="*10, " Inference latency & Throughput benchmarking ", "="*10)
metered_model.ittp_warmup = 10  # customize the number of warm-up runs
# NOTE(review): presumably the number of timed repetitions per module -- confirm
metered_model.ittp_benchmark_time = 20
print(metered_model.ittp)
tb, data = metered_model.profile('ittp', no_tree=True)

Warming Up: 100%|██████████| 10/10 [00:00<00:00, 398.08it/s]
Benchmark Inference Time & Throughput: 100%|██████████| 1280/1280 [00:00<00:00, 5445.62module/s]


Warming Up: 100%|██████████| 10/10 [00:00<00:00, 430.13it/s]
Benchmark Inference Time & Throughput: 100%|██████████| 1280/1280 [00:00<00:00, 4705.27module/s]


In [8]:
# 6. Visualization (demonstrates live rendering and configuration)
# --------------------------------------------
# 6.1 Structure tree rendering customization
from rich.box import ROUNDED

print("="*10, " Tree rendering customization ", "="*10)
metered_model.tree_fold_repeat = True
# Per-level settings: a "default" entry plus an override keyed by level "1".
# The <node_id>, <name> and <module_repr> placeholders are presumably filled
# in per node by torchmeter -- confirm against its docs.
metered_model.tree_levels_args = {
    "default": {"label": "[b gray35](<node_id>) [green]<name>[/green] [cyan]<module_repr>[/]"},
    "1": {"guide_style": "cornflower_blue"}
}
# Settings for how a folded repeat block is displayed (title text and box style)
metered_model.tree_repeat_block_args = {
    "title": "[[b]<repeat_time>[/b]] [i]Times Repeated[/]",
    "box": ROUNDED
}
print(metered_model.structure)

In [9]:
# 6.2 Table rendering customization
print("="*10, " Table rendering customization ", "="*10)
# A single setting can be changed through attribute access ...
metered_model.table_column_args.justify = "left"
# ... or a group of settings can be assigned at once as a dict
metered_model.table_display_args = {
    "style": "#af8700", # or rgb(175,135,0)
    "show_lines": True,
    "show_edge": False
}

tb, data = metered_model.profile("param", no_tree=True)

In [10]:
# 6.2 Table content rendering customization
# 6.2.1 Tree + table
print("="*10, " Table report with tree ", "="*10)
# Reset the config, then set render_interval to 0 again after the reset
cfg.restore()
cfg.render_interval = 0
# no_tree=False, unlike the no_tree=True used by the other profile() calls
tb, data = metered_model.profile("param", 
                                 no_tree=False)

In [11]:
# 6.2.2 Table with raw data
print("="*10, " Table report with raw data ", "="*10)
# raw_data=True: presumably shows unformatted values in the table -- confirm
tb, data = metered_model.profile("param", no_tree=True, raw_data=True)

In [12]:
# 6.2.3 Table column customization
print("="*10, " Table structure customization ", "="*10)
# NOTE(review): "Operation_Name" appears in both pick_cols and exclude_cols;
# presumably the exclusion wins -- confirm against the torchmeter docs.
tb, data = metered_model.profile(
    "mem", 
    no_tree=True, 
    pick_cols=["Operation_Id", "Operation_Name", "Param_Cost", "Buffer_Cost", "Output_Cost", "Total"], 
    exclude_cols=["Operation_Name"],
    # mapping of original column name -> displayed column name
    custom_cols={"Operation_Id": "ID", 
                 "Param_Cost": "Param Cost", 
                 "Buffer_Cost": "Buffer Cost", 
                 "Output_Cost": "Output Cost"},
    keep_custom_name = True,
    # add an "Index" column at position 0 holding 0..len(df)-1
    newcol_name="Index",
    newcol_func=lambda df: list(range(len(df))),
    newcol_type=int,
    newcol_idx=0,
    keep_new_col=True
)

# check the new columns are kept
print(metered_model.table_cols("mem"))

In [13]:
# 6.3 Programmable tabular report
cfg.restore() # restore all customized settings
# Set render_interval to 0 again after the restore
cfg.render_interval = 0

def newcol_logic(df):
    """Build the values of a percentage column from the 'Number' column.

    Each entry of ``df['Number']`` is divided by
    ``metered_model.param.TotalNum``, scaled by 100 and formatted with four
    decimals followed by " %".

    Args:
        df: table data exposing a 'Number' column whose column object
            supports ``map_elements`` (presumably a polars DataFrame --
            confirm against the torchmeter docs).

    Returns:
        The result of ``map_elements`` with ``return_dtype=str``.
    """
    def as_percentage(x):
        # Share of the model's total parameter count, e.g. "12.3456 %"
        return f"{100 * x / metered_model.param.TotalNum:.4f} %"

    return df['Number'].map_elements(as_percentage, return_dtype=str)

print("="*10, " Programmable tabular report ", "="*10)
# Show the column names of the 'param' table before customization
origin_col = metered_model.table_cols('param')
print(f"origin cols: {origin_col}")

tb, data = metered_model.profile(
    'param', 
    no_tree = True,
    exclude_cols=["Operation_Name"],
    # "Numeric_Num" is renamed to "Number", the column newcol_logic reads
    custom_cols={"Operation_Id": 'ID', 
                 "Param_Name": 'Param Name', 
                 "Requires_Grad": 'Trainable',
                 "Numeric_Num": "Number"},
    # add a 'Percentage' column whose values come from newcol_logic
    newcol_name='Percentage',
    newcol_func=newcol_logic,
    newcol_type=str
)

In [14]:
# 6.4 tabular report export
print("="*10, " Tabular report export ", "="*10)
# Same customization as the programmable report, with show=False and a
# save_to target added.
tb, data = metered_model.profile(
    'param', 
    show=False,
    no_tree = True,
    exclude_cols=["Operation_Name"],
    custom_cols={"Operation_Id": 'ID', 
                 "Param_Name": 'Param Name', 
                 "Requires_Grad": 'Trainable',
                 "Numeric_Num": "Number"},
    newcol_name='Percentage',
    newcol_func=newcol_logic,
    newcol_type=str,
    save_to='./param_report.xlsx' # or csv
)

In [15]:
# 7. Cross-Platform Support
print("="*10, " Cross-Platform Support ", "="*10)

# Move the metered model with a torch-style `.to()` call
metered_model.to("cpu")
print(metered_model.device)

if torch.cuda.is_available():
    # The device can also be switched by assigning to the `device` attribute
    metered_model.device = "cuda:0"
    print(metered_model.device)

In [16]:
# 8. Model summary

print("="*10, " Model Information ", "="*10)
# `model_info` is read as an attribute, not called
print(metered_model.model_info)

In [17]:
# 9. Statistics Overview
# overview() called with no arguments
print("="*10, " Statistics Overview ", "="*10)
print(metered_model.overview())

# Same call with show_warning=False
print("="*10, " Statistics Overview (no warnings) ", "="*10)
print(metered_model.overview(show_warning=False))

# Only the named statistics ("param" and "mem") are requested
print("="*10, " Statistics Overview (custom) ", "="*10)
print(metered_model.overview("param", "mem"))

Warming Up: 100%|██████████| 10/10 [00:00<00:00, 423.15it/s]


Benchmark Inference Time & Throughput: 100%|██████████| 1280/1280 [00:00<00:00, 4652.85module/s]


Warming Up: 100%|██████████| 10/10 [00:00<00:00, 429.08it/s]
Benchmark Inference Time & Throughput: 100%|██████████| 1280/1280 [00:00<00:00, 4717.00module/s]


In [18]:
# 10. Advanced Usage

## 10.1 post export of tabular report
print("="*10, " Custom export ", "="*10)
# `data` is the second value returned by an earlier profile() call
metered_model.table_renderer.export(df=data,
                                    save_path=".",
                                    format="csv",
                                    file_suffix="custom_export",
                                    raw_data=True)

In [19]:
## 10.2 repeat footer
# Imports used by this cell's toy model: torch.nn and random.sample
import torch.nn as nn
from random import sample

print("="*10, " Custom repeat block footer ", "="*10)

class RepeatModel(nn.Module):
    """Toy module whose ``layers`` list is a short window of modules repeated.

    A window of ``repeat_winsz`` distinct modules is drawn at random from a
    fixed pool (Linear, ReLU, Identity) and the window is repeated
    ``repeat_time`` times. The repetitions reuse the same module objects, so
    ``layers`` holds ``repeat_winsz * repeat_time`` entries but only
    ``repeat_winsz`` distinct instances. No ``forward`` is defined.

    Args:
        repeat_winsz: number of modules in one window; ``random.sample``
            raises ``ValueError`` if it exceeds the pool size of 3.
        repeat_time: how many times the window is repeated.
    """

    def __init__(self, repeat_winsz:int=1, repeat_time:int=2):
        super().__init__()

        # Fixed pool the repeat window is drawn from
        candidate_pool = [nn.Linear(10, 10), nn.ReLU(), nn.Identity()]

        # Sampling without replacement keeps the modules in a window distinct
        window = sample(candidate_pool, repeat_winsz)

        # List multiplication repeats references, not copies
        self.layers = nn.ModuleList(window * repeat_time)

# Rebind `metered_model` to a Meter wrapping the toy repeat model, on CPU
metered_model = Meter(RepeatModel(repeat_winsz=2, repeat_time=3), 
                      device="cpu")

### 10.2.1 change with a hard-coding string 
print("-"*10, " Using hard-coding string ", "-"*10)
metered_model.tree_renderer.repeat_footer = "My custom footer"
print(metered_model.structure)

In [20]:
### 10.2.2 change with a string which accesses operation node attributes
print("-"*10, " Using dynamic string with attributes resolved ", "-"*10)
# `<type>` is a placeholder; presumably replaced with the node's `type`
# attribute when the tree is rendered -- confirm against the torchmeter docs.
metered_model.tree_renderer.repeat_footer = "This module type is <type>"
print(metered_model.structure)

In [21]:
### 10.2.3 change with a function which accepts an attr-dictionary and returns a string
print("-"*10, " Using function ", "-"*10)
def my_footer(attr_dict):
    """Return the footer text of a repeat block, based on its window size.

    Args:
        attr_dict: mapping that provides the key "repeat_winsz".

    Returns:
        A sentence stating the window size when it is greater than 1,
        otherwise a fixed sentence for the single-module case.
    """
    window_size = attr_dict["repeat_winsz"]
    return (
        f"There are {window_size} modules in a repeat window"
        if window_size > 1
        else "The repeat window only contains one module"
    )
# The function object itself is assigned, not the result of calling it
metered_model.tree_renderer.repeat_footer = my_footer
print(metered_model.structure)

In [22]:
## 10.3 config management

print("="*10, " Efficient config management ", "="*10)
### 10.3.1 show config
print("-"*10, " config display ", "-"*10)
# Printing the config object displays its current settings
print(cfg)

In [23]:
### 10.3.2 retrieve config setting
print("-"*10, " config settings retrieval ", "-"*10)
# Settings, including nested ones, are read through attribute access.
# NOTE(review): each argument ends with "\n" and print() also inserts its
# separator between arguments, so every line after the first presumably
# starts with a space.
print(
    f"config_file: {cfg.config_file}\n",
    f"render time interval: {cfg.render_interval}\n",
    f"tree default guide line style: {cfg.tree_levels_args.default.guide_style}\n",
    f"table col justify: {cfg.table_column_args.justify}\n",
    f"gap between tree and table in profiling: {cfg.combine.horizon_gap}"
)

In [24]:
### 10.3.3 change config settings
print("-"*10, " change config settings ", "-"*10)
cfg.render_interval = 0.5
# guide_style is presumably handed to rich as a style definition, so it must
# be a style rich can parse, such as "bold"
cfg.tree_levels_args.default.guide_style = "bold"
# A whole group of settings can be assigned at once as a dict
cfg.table_display_args = {
    "show_header": False,
    "show_lines": True
}
# Read the values back to confirm the changes took effect
print(cfg.render_interval)
print(cfg.tree_levels_args.default.guide_style)
print(cfg.table_display_args.show_header)
print(cfg.table_display_args.show_lines)

In [25]:
import os
### 10.3.4 save config
print("-"*10, " dump config settings ", "-"*10)
des = "./my_config.yaml"
cfg.dump(save_path=des)
# Resolve to an absolute path and report success only if the file exists
abs_des = os.path.abspath(des)
if os.path.exists(abs_des):
    print(f"config dumped successfully to {abs_des}")

In [26]:
### 10.3.5 restore config
print("-"*10, " restore all config settings ", "-"*10)
cfg.restore()
# Print the config again so the restored values can be inspected
print(cfg)

In [27]:
### 10.3.6 reuse config
print("-"*10, " config reuse ", "-"*10)
# `abs_des` is the absolute path of the file written in step 10.3.4
new_cfg = get_config(config_path=abs_des)
print(f"reuse: {new_cfg.config_file}")
print(new_cfg)
print("You can compare with the restored settings in last cell, ",
      "to check if the settings before restore are reused.")