In [29]:
import torch
# from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Model, GPT2Config

# tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
# model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")

# print(model)

In [30]:
# prompt = "GPT2 is a model developed by OpenAI."
# input_ids = tokenizer.encode(prompt, return_tensors="pt")
# attention_mask = torch.ones(input_ids.shape, dtype=torch.long)

# gen_tokens = model.generate(
#     input_ids,
#     do_sample=True,
#     temperature=0.9,
#     max_length=100,
#     attention_mask=attention_mask,
#     pad_token_id=tokenizer.pad_token_id,
# )
# gen_text = tokenizer.batch_decode(gen_tokens)[0]

# print(gen_text)

In [31]:
from transformers import GPT2Config, GPT2Model
from transformers.utils import ModelOutput

config = GPT2Config()
model = GPT2Model(config)

In [32]:
class MetadataTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data, centered=False, last_modules=[]):
        self = torch.Tensor._make_subclass(cls, data)
        self.centered = centered
        self.last_modules = last_modules
        return self

    def __getattr__(self, name):
        if name == 'centered':
            return False
        elif name == 'last_modules':
            return []
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    # def __init__(self, data, centered=False):
    #     self.centered = centered

    def __add__(self, other):
        result = super(MetadataTensor, self).__add__(other)
        if isinstance(other, MetadataTensor):
            centered = self.centered and other.centered
        else:
            centered = False
        last_modules = self.last_modules + other.last_modules
        return MetadataTensor(result, centered=centered, last_modules=last_modules)

    def __repr__(self):
        if hasattr(self, "centered"):
            centered = self.centered
        else:
            centered = None
        return f"MetadataTensor({super().__repr__()}, centered={centered})"



In [33]:
ln_cnt = 0
foldable_cnt = 0
center_modules = []

layer_info = []
hooks = []
indent = 0


In [34]:
def apply_func_to_nested_tuple(t, func):
    """
    递归地对嵌套的 tuple 的每个元素应用函数 func，并保持原有的嵌套结构。

    :param t: 一个嵌套的 tuple
    :param func: 要应用到每个非 tuple 元素的函数
    :return: 返回一个结构与原 tuple 相同，但元素经过 func 操作的 tuple
    """
    if isinstance(t, tuple):
        # 如果是 tuple，递归地对每个元素应用函数
        return tuple(apply_func_to_nested_tuple(item, func) for item in t)
    else:
        # 如果不是 tuple，应用函数 func
        return func(t)


def get_shape(t):
    if isinstance(t, torch.Tensor):
        return tuple(t.size())
    else:
        return None


def get_centered(t):
    if isinstance(t, MetadataTensor) and hasattr(t, 'centered'):
        return t.centered
    else:
        return None


def get_last_modules(t):
    if isinstance(t, MetadataTensor) and hasattr(t, 'last_modules'):
        return t.last_modules
    else:
        return []


def hook_pre_fn(module, inputs):
    global indent
    print('  ' * indent, '< ', module.__class__.__name__, '>')
    indent += 1
    for name, sub_module in module.named_children():
        print('  ' * indent, name, ':', sub_module.__class__.__name__)

    inputs_centered = True
    last_modules = []

    def input_func(input):
        nonlocal inputs_centered
        if isinstance(input, MetadataTensor):
            inputs_centered = inputs_centered and input.centered
            last_modules.append(input.last_modules)
        else:
            inputs_centered = False
        print('  ' * indent, '<-', input.__class__.__name__, get_centered(input), get_shape(input), get_last_modules(input))

    apply_func_to_nested_tuple(inputs, input_func)

    module._input_centered = inputs_centered
    module._last_modules = last_modules

    if 'LayerNorm' in module.__class__.__name__:
        global ln_cnt
        ln_cnt += 1
        if inputs_centered:
            global foldable_cnt
            global center_modules
            foldable_cnt += 1
            center_modules.append(module)


def hook_fn(module, inputs, outputs):
    if isinstance(outputs, ModelOutput):
        return outputs
    global indent

    single = False
    if isinstance(outputs, torch.Tensor):
        outputs = (outputs,)
        single = True

    def output_func(output):
        if not isinstance(output, MetadataTensor) and isinstance(output, torch.Tensor):
            output = MetadataTensor(output, centered=False)
        if isinstance(output, MetadataTensor):
            module_name = module.__class__.__name__
            if 'LayerNorm' in module_name or 'Linear' in module_name or 'Conv' in module_name or 'Embedding' in module_name:
                output.centered = True
                output.last_modules = [module]
            elif 'Dropout' in module_name:
                output.centered = module._input_centered
                output.last_modules = module._last_modules
            print('  ' * indent, '->',  output.__class__.__name__, get_centered(output), get_shape(output), get_last_modules(output))

        return output

    new_outputs = apply_func_to_nested_tuple(outputs, output_func)

    indent -= 1
    print('  ' * indent, '</', module.__class__.__name__, '>')

    if single:
        new_outputs = new_outputs[0]
    else:
        new_outputs = tuple(new_outputs)

    return new_outputs


In [35]:
for hook in hooks:
    hook.remove()

for layer in model.named_modules():
    hooks.append(layer[1].register_forward_pre_hook(hook_pre_fn))
    hooks.append(layer[1].register_forward_hook(hook_fn))

input_ids = torch.randint(0, 1000, (1, 128))
my_input_ids = MetadataTensor(input_ids, centered=False)
out = model(my_input_ids)

for hook in hooks:
    hook.remove()

print('LayerNorm:', ln_cnt)
print('Foldable:', foldable_cnt)
print('Center modules:', center_modules)

 <  GPT2Model >
   wte : Embedding
   wpe : Embedding
   drop : Dropout
   h : ModuleList
   ln_f : LayerNorm
   <- MetadataTensor False (1, 128) []
   <  Embedding >
     <- MetadataTensor False (1, 128) []
     -> MetadataTensor True (1, 128, 768) [Embedding(50257, 768)]
   </ Embedding >
   <  Embedding >
     <- Tensor None (1, 128) []
     -> MetadataTensor True (1, 128, 768) [Embedding(1024, 768)]
   </ Embedding >
   <  Dropout >
     <- MetadataTensor True (1, 128, 768) [Embedding(50257, 768), Embedding(1024, 768)]
     -> MetadataTensor True (1, 128, 768) [[Embedding(50257, 768), Embedding(1024, 768)]]
   </ Dropout >
   <  GPT2Block >
     ln_1 : LayerNorm
     attn : GPT2SdpaAttention
     ln_2 : LayerNorm
     mlp : GPT2MLP
     <- MetadataTensor True (1, 128, 768) [[Embedding(50257, 768), Embedding(1024, 768)]]
     <  LayerNorm >
       <- MetadataTensor True (1, 128, 768) [[Embedding(50257, 768), Embedding(1024, 768)]]
       -> MetadataTensor True (1, 128, 768) [LayerNo