In [18]:
import torch
from transformers import BertModel, BertTokenizer
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

inputs = tokenizer("Hello, I love transformers!", return_tensors="pt")

# 维护缩进层级
indent_level = 0
module_stack = []
logs = []
levels = []

def shape_repr(x):
    if isinstance(x, torch.Tensor):
        return tuple(x.shape)
    elif isinstance(x, (list, tuple)):
        return [tuple(t.shape) for t in x if isinstance(t, torch.Tensor)]
    elif isinstance(x, (BaseModelOutputWithPoolingAndCrossAttentions, BaseModelOutputWithPastAndCrossAttentions)):
        shapes = []
        for v in x.__dict__.values():
            if isinstance(v, torch.Tensor):
                shapes.append(tuple(v.shape))
            elif isinstance(v, (list, tuple)):
                shapes.extend([tuple(t.shape) for t in v if isinstance(t, torch.Tensor)])
        return shapes
    return None

def pre_hook(module, inputs):
    global indent_level
    module_stack.append(module.__class__.__name__)
    prefix = "    " * indent_level
    in_shape = shape_repr(inputs)
    log_str = f"{prefix}{module.__class__.__name__} 输入: {in_shape}"
    # print(log_str)
    # logs.append(log_str)
    logs.append({
        "prefix": prefix,
        "module": module.__class__.__name__,
        "shape": in_shape,
        "type": "输入"
    })
    levels.append(indent_level)
    indent_level += 1

def post_hook(module, inputs, output):
    global indent_level
    indent_level -= 1
    module_stack.pop()
    prefix = "    " * indent_level
    out_shape = shape_repr(output)
    # print(f"{prefix}{module.__class__.__name__} 输出: {out_shape}")
    log_str = f"{prefix}{module.__class__.__name__} 输出: {out_shape}"
    # print(log_str)
    logs.append({
        "prefix": prefix,
        "module": module.__class__.__name__,
        "shape": out_shape,
        "type": "输出"
    })
    levels.append(indent_level)

# 给所有子模块注册 pre_hook 和 post_hook
for name, module in model.named_modules():
    if name != "":
        module.register_forward_pre_hook(pre_hook)
        module.register_forward_hook(post_hook)

# 前向传播
with torch.no_grad():
    outputs = model(**inputs)

# 合并相同模块的日志
for i in range(0, len(logs)):
    if i + 1 <= len(logs) - 1 and levels[i] == levels[i + 1]:
        print(f"{logs[i]['prefix']}{logs[i]['module']} 输入: {logs[i]['shape']} -> 输出: {logs[i + 1]['shape']}")
    elif i - 1 >= 0 and levels[i] == levels[i - 1]:
        continue
    else:
        print(f"{logs[i]['prefix']}{logs[i]['module']} {logs[i]['type']}: {logs[i]['shape']}")

BertEmbeddings 输入: []
    Embedding 输入: [(1, 8)] -> 输出: (1, 8, 768)
    Embedding 输入: (1, 8, 768) -> 输出: [(1, 8)]
    Embedding 输入: [(1, 8)] -> 输出: (1, 8, 768)
    Embedding 输入: (1, 8, 768) -> 输出: [(1, 8)]
    Embedding 输入: [(1, 8)] -> 输出: (1, 8, 768)
    Embedding 输入: (1, 8, 768) -> 输出: [(1, 8, 768)]
    LayerNorm 输入: [(1, 8, 768)] -> 输出: (1, 8, 768)
    LayerNorm 输入: (1, 8, 768) -> 输出: [(1, 8, 768)]
    Dropout 输入: [(1, 8, 768)] -> 输出: (1, 8, 768)
BertEmbeddings 输入: (1, 8, 768) -> 输出: [(1, 8, 768)]
    BertLayer 输入: [(1, 8, 768)]
        BertAttention 输入: [(1, 8, 768)]
            BertSdpaSelfAttention 输入: [(1, 8, 768)]
                Linear 输入: [(1, 8, 768)] -> 输出: (1, 8, 768)
                Linear 输入: (1, 8, 768) -> 输出: [(1, 8, 768)]
                Linear 输入: [(1, 8, 768)] -> 输出: (1, 8, 768)
                Linear 输入: (1, 8, 768) -> 输出: [(1, 8, 768)]
                Linear 输入: [(1, 8, 768)] -> 输出: (1, 8, 768)
            BertSdpaSelfAttention 输入: [(1, 8, 768)] -> 输出: [(1, 8, 768

In [15]:
print(inputs.keys())
print(inputs.input_ids)
print(inputs.token_type_ids)
print(inputs.attention_mask)

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
tensor([[  101,  7592,  1010,  1045,  2293, 19081,   999,   102]])
tensor([[0, 0, 0, 0, 0, 0, 0, 0]])
tensor([[1, 1, 1, 1, 1, 1, 1, 1]])


In [26]:
import torch
from transformers import BertModel, BertTokenizer

# 用于存储每个模块的输入输出和 class
module_info = {}
module_children = {}

# 构建模块树
def build_children_dict(module, parent_name=""):
    children = list(module.named_children())
    module_children[parent_name] = [name for name, _ in children]
    for name, child in children:
        full_name = f"{parent_name}.{name}" if parent_name else name
        build_children_dict(child, full_name)

# Hook 函数记录输入输出和 class
def hook_fn(module, input, output, name):
    module_info[name] = {
        "class": module.__class__.__name__,
        "input": [i.shape if isinstance(i, torch.Tensor) else type(i) for i in input],
        "output": output.shape if isinstance(output, torch.Tensor) else type(output),
    }

# 递归打印模块输入输出
def print_module_info(module_name, indent=0):
    indent_str = "  " * indent
    info = module_info.get(module_name)

    children = module_children.get(module_name, [])
    if info:
        if not children:
            print(f"{indent_str}{info['class']} ({module_name}) input: {info['input']} -> output: {info['output']}")
        else:
            # 有子模块，先打印 input
            print(f"{indent_str}{info['class']} ({module_name}) input: {info['input']}")
    for child_name in children:
        full_child_name = f"{module_name}.{child_name}" if module_name else child_name
        print_module_info(full_child_name, indent + 1)
    if info and children:
        print(f"{indent_str}{info['class']} ({module_name}) output: {info['output']}")

# 初始化模型和 tokenizer
model_name = "bert-base-uncased"
model = BertModel.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)

# 构建模块树
build_children_dict(model)

# 注册 hook
for name, module in model.named_modules():
    module.register_forward_hook(lambda m, i, o, n=name: hook_fn(m, i, o, n))

# 构造输入
# text = "Hello, this is a test."
# inputs = tokenizer(text, return_tensors="pt")
sentences = ["Hello, I love transformers!", "BERT is a powerful model."]
inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)

# 前向传播
with torch.no_grad():
    output = model(**inputs)

# 打印模块信息
print_module_info("")


BertModel () input: []
  BertEmbeddings (embeddings) input: []
    Embedding (embeddings.word_embeddings) input: [torch.Size([2, 8])] -> output: torch.Size([2, 8, 768])
    Embedding (embeddings.position_embeddings) input: [torch.Size([1, 8])] -> output: torch.Size([1, 8, 768])
    Embedding (embeddings.token_type_embeddings) input: [torch.Size([2, 8])] -> output: torch.Size([2, 8, 768])
    LayerNorm (embeddings.LayerNorm) input: [torch.Size([2, 8, 768])] -> output: torch.Size([2, 8, 768])
    Dropout (embeddings.dropout) input: [torch.Size([2, 8, 768])] -> output: torch.Size([2, 8, 768])
  BertEmbeddings (embeddings) output: torch.Size([2, 8, 768])
  BertEncoder (encoder) input: [torch.Size([2, 8, 768])]
      BertLayer (encoder.layer.0) input: [torch.Size([2, 8, 768]), <class 'NoneType'>, <class 'NoneType'>, <class 'NoneType'>, <class 'NoneType'>, <class 'NoneType'>, <class 'bool'>]
        BertAttention (encoder.layer.0.attention) input: [torch.Size([2, 8, 768]), <class 'NoneType'>

In [27]:
print(inputs)

{'input_ids': tensor([[  101,  7592,  1010,  1045,  2293, 19081,   999,   102],
        [  101, 14324,  2003,  1037,  3928,  2944,  1012,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1]])}
