In [1]:
# --- Setup: imports, reproducibility, model, and the module-statistics counter ---
import torch
from transformers import GPT2Config, GPT2Model

import utils  # project-local helpers (Counter, create_hook_fns, MetadataTensor are used below)

# Seed torch's global RNG before anything random happens. GPT2Model(config)
# builds a randomly initialised model (no pretrained checkpoint is loaded), and a
# later cell samples random input ids with torch.randint. Without a seed, every
# "Restart & Run All" produces different printed outputs.
SEED = 0
torch.manual_seed(SEED)

config = GPT2Config()
model = GPT2Model(config)

# Collects statistics and hook handles during the instrumented forward pass.
counter = utils.Counter()


In [2]:
# Instrument every submodule with forward (pre-)hooks, run one forward pass on a
# MetadataTensor input, and report what the counter collected.

# Detach any hooks left over from a previous (possibly interrupted) run of this
# cell so the model is never double-instrumented, then drop the dead handles so
# the list does not grow on every re-run.
for hook in counter.hooks:
    hook.remove()
counter.hooks.clear()

hook_pre_fn, hook_fn = utils.create_hook_fns(counter)

SEQ_LEN = 128        # sequence length of the probe input
MAX_TOKEN_ID = 1000  # exclusive upper bound for sampled token ids (well below the 50257-entry vocab)

try:
    # named_modules() yields (name, module) pairs and includes the root model itself.
    for _name, module in model.named_modules():
        counter.hooks.append(module.register_forward_pre_hook(hook_pre_fn))
        counter.hooks.append(module.register_forward_hook(hook_fn))

    input_ids = torch.randint(0, MAX_TOKEN_ID, (1, SEQ_LEN))
    my_input_ids = utils.MetadataTensor(input_ids, centered=False)
    out = model(my_input_ids)
finally:
    # Always detach the hooks, even if the forward pass raises; otherwise the
    # model would stay instrumented for every later call in this kernel.
    for hook in counter.hooks:
        hook.remove()
    counter.hooks.clear()

# NOTE(review): the counter's statistics are not reset here, so re-running this
# cell without re-running the cell that creates `counter` presumably accumulates
# counts — confirm against utils.Counter.
print('LayerNorm:', counter.ln_cnt)
print('Foldable:', counter.foldable_cnt)
print('Center modules:', counter.center_modules)

 <  GPT2Model >
   wte : Embedding
   wpe : Embedding
   drop : Dropout
   h : ModuleList
   ln_f : LayerNorm
   <- MetadataTensor False (1, 128) 0 set()
   <  Embedding >
     <- MetadataTensor False (1, 128) 0 set()
     -> MetadataTensor True (1, 128, 768) 1 {Embedding(50257, 768)}
   </ Embedding >
   <  Embedding >
     <- Tensor None (1, 128) 0 set()
     -> MetadataTensor True (1, 128, 768) 1 {Embedding(1024, 768)}
   </ Embedding >
   <  Dropout >
     <- MetadataTensor True (1, 128, 768) 2 {Embedding(1024, 768), Embedding(50257, 768)}
     -> MetadataTensor True (1, 128, 768) 2 {Embedding(1024, 768), Embedding(50257, 768)}
   </ Dropout >
   <  GPT2Block >
     ln_1 : LayerNorm
     attn : GPT2SdpaAttention
     ln_2 : LayerNorm
     mlp : GPT2MLP
     <- MetadataTensor True (1, 128, 768) 2 {Embedding(1024, 768), Embedding(50257, 768)}
     <  LayerNorm >
       <- MetadataTensor True (1, 128, 768) 2 {Embedding(1024, 768), Embedding(50257, 768)}
       -> MetadataTensor True (

In [3]:
# Rewrite the model in place using what the instrumentation pass recorded:
# each recorded LayerNorm goes through modules.replace_layer_norm, then every
# module the counter flagged goes through modules.center_modules.
# NOTE(review): `model` is mutated in place, so re-running this cell applies the
# transformations a second time — re-run from the model-creation cell instead.
import modules

for ln_module in counter.layernorms:
    modules.replace_layer_norm(ln_module)

for target_module in counter.center_modules:
    modules.center_modules(target_module)

print(model)

# Sanity check: mean over the last dimension of block 0's c_attn weight
# (presumably the fused Q/K/V projection, nf=2304 = 3 x 768).
# NOTE(review): c_attn appears to consume ln_1's output rather than write into
# the residual stream, so it may not be one of the centered modules — the
# printed means are ~1e-4, not ~0. Confirm this is the tensor meant to be checked.
first_qkv_weight = model.h[0].attn.c_attn.weight
print(first_qkv_weight.mean(dim=-1))


GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPT2Block(
      (ln_1): SOLayerNorm()
      (attn): GPT2SdpaAttention(
        (c_attn): Conv1D(nf=2304, nx=768)
        (c_proj): Conv1D(nf=768, nx=768)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): SOLayerNorm()
      (mlp): GPT2MLP(
        (c_fc): Conv1D(nf=3072, nx=768)
        (c_proj): Conv1D(nf=768, nx=3072)
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): SOLayerNorm()
)
tensor([ 7.9411e-04,  4.0817e-04,  3.6760e-05,  3.8906e-05,  2.4739e-04,
        -6.6084e-05, -4.0909e-04,  5.5539e-04, -4.7574e-04,  9.6237e-04,
        -2.1080e-04, -6.5660e-04, -8.8870e-05, -9.8644e-06,  3.8810e-04,
        -1.0388e-03, -5.2379e-04, -8.9526e-05, -2.2116e-05, -2.9643e-04,
         1.6156e-04,  4.5462