In [9]:
# Imports, grouped in one place: standard library, third-party, project-local.
import random

import numpy as np
import torch
from transformers import GPT2Config, GPT2Model

import utils

# Single source of truth for every RNG seeded below.
SEED = 0

# Seed each RNG the notebook may draw from so that the randomly initialised
# weights and the sampled input ids are repeatable after a kernel restart.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Two GPT-2 models built from the same default config. They are constructed
# directly from the config (no pretrained checkpoint is loaded here).
config = GPT2Config()
original_model = GPT2Model(config)
folded_model = GPT2Model(config)

# Copy the reference weights into the second model so both start out with
# identical parameters; `folded_model` is the one later cells modify.
folded_model.load_state_dict(original_model.state_dict())

# Put both models in evaluation mode before any forward pass is compared.
original_model.eval()
folded_model.eval()

# Accumulator handed to the analysis hooks in the next cell.
# NOTE(review): its fields are defined in the project-local `utils` module.
counter = utils.Counter()


In [10]:
# Build the (pre-hook, hook) pair bound to `counter`.
# NOTE(review): presumably these hooks record what they see into `counter`,
# since its fields are read back at the end of this cell -- confirm in utils.
hook_pre_fn, hook_fn = utils.create_analyse_hook_fns(counter)

# A (1, 128) batch of integer ids drawn from [0, 1000).
input_ids = torch.randint(low=0, high=1000, size=(1, 128))
# Wrap the ids so they carry the `centered=False` flag into the model.
my_input_ids = utils.MetadataTensor(input_ids, centered=False)

# One forward pass with the hooks installed; the model output itself is
# discarded, only what the hooks collected is of interest here.
with utils.HookManager(folded_model, hook_fn, hook_pre_fn):
    folded_model(my_input_ids)

# Summarise what the analysis pass collected.
for label, value in (
    ('LayerNorm:', counter.ln_cnt),
    ('Foldable:', counter.foldable_cnt),
    ('Center modules:', counter.center_modules),
):
    print(label, value)

 <  GPT2Model >
   wte : Embedding
   wpe : Embedding
   drop : Dropout
   h : ModuleList
   ln_f : LayerNorm
   <- MetadataTensor False (1, 128) 0 set()
   <  Embedding >
     <- MetadataTensor False (1, 128) 0 set()
     -> MetadataTensor True (1, 128, 768) 1 {Embedding(50257, 768)}
   </ Embedding >
   <  Embedding >
     <- Tensor None (1, 128) 0 set()
     -> MetadataTensor True (1, 128, 768) 1 {Embedding(1024, 768)}
   </ Embedding >
   <  Dropout >
     <- MetadataTensor True (1, 128, 768) 2 {Embedding(50257, 768), Embedding(1024, 768)}
     -> MetadataTensor True (1, 128, 768) 2 {Embedding(50257, 768), Embedding(1024, 768)}
   </ Dropout >
   <  GPT2Block >
     ln_1 : LayerNorm
     attn : GPT2SdpaAttention
     ln_2 : LayerNorm
     mlp : GPT2MLP
     <- MetadataTensor True (1, 128, 768) 2 {Embedding(50257, 768), Embedding(1024, 768)}
     <  LayerNorm >
       <- MetadataTensor True (1, 128, 768) 2 {Embedding(50257, 768), Embedding(1024, 768)}
       -> MetadataTensor True (

In [11]:
import modules

# Replace every LayerNorm listed in `counter.layernorms`.
# NOTE(review): `counter.layernorms` is presumably filled by the analysis
# pass in the previous cell -- this cell must run after it.
for ln_module in counter.layernorms:
    modules.replace_layer_norm(ln_module)

# Apply the centering transform to every module listed in
# `counter.center_modules`.
for target_module in counter.center_modules:
    modules.center_modules(target_module)

# Show the resulting module tree so the replacements are visible.
print(folded_model)

# Mean over the last dimension of the first block's `c_attn` weight.
c_attn_weight = folded_model.h[0].attn.c_attn.weight
print(c_attn_weight.mean(dim=-1))


GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPT2Block(
      (ln_1): SOLayerNorm()
      (attn): GPT2SdpaAttention(
        (c_attn): Conv1D(nf=2304, nx=768)
        (c_proj): Conv1D(nf=768, nx=768)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): SOLayerNorm()
      (mlp): GPT2MLP(
        (c_fc): Conv1D(nf=3072, nx=768)
        (c_proj): Conv1D(nf=768, nx=3072)
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): SOLayerNorm()
)
tensor([-8.8657e-04,  2.7004e-04, -5.3564e-04,  2.2696e-04,  5.7918e-05,
        -1.8646e-04,  6.6418e-04,  3.9907e-04,  1.9836e-04, -6.2640e-04,
        -1.2385e-03,  6.2525e-04,  2.1106e-04,  4.6794e-04,  3.6536e-04,
        -4.8721e-05, -1.9707e-04, -2.0062e-04,  5.5034e-04, -7.2874e-04,
        -2.3469e-04,  2.4639

In [36]:
def hook_out_fn(module, input, output):
    """Print a small corner of a module's output at high precision.

    Has the ``(module, input, output)`` signature used by forward hooks and
    returns ``None``; it only prints.

    Args:
        module: The module whose output is being inspected; only its class
            name is used, as a label for the printed line.
        input: Unused. Present to match the hook signature.
        output: The module's output. If it is a tuple, only its first element
            is inspected. The inspected object must expose ``.shape`` and
            support tuple indexing (e.g. a ``torch.Tensor``).

    The printed slice takes index 0 along every leading dimension and the
    first 4 entries along each of the last two dimensions. Outputs with fewer
    than two dimensions are sliced along the dimensions they do have, instead
    of failing with "too many indices".
    """
    if isinstance(output, tuple):
        output = output[0]

    rank = len(output.shape)
    # Never emit more slices than the output has dimensions: a 1-D output
    # gets one slice, a 0-D output gets an empty index.
    sliced_dims = min(rank, 2)
    index = (0,) * (rank - sliced_dims) + (slice(None, 4),) * sliced_dims

    torch.set_printoptions(precision=10, sci_mode=True)
    try:
        print(module.__class__.__name__, output[index])
    finally:
        # Print options are process-global; restore them even if indexing or
        # printing raised, so later output is not stuck in scientific mode.
        # This resets to the "default" profile, as the original code did.
        torch.set_printoptions(profile="default")


In [38]:
# Fresh (1, 128) batch of integer ids drawn from [0, 1000); the same batch is
# fed to both models so their printed activations are directly comparable.
input_ids = torch.randint(0, 1000, (1, 128))

# These are inspection-only forward passes, so run them without autograd:
# no graph is built and the printed tensors carry no grad_fn.
with torch.no_grad():
    # Hook only the second block's attention module of each model so that
    # hook_out_fn prints one slice per model.
    with utils.HookManager(folded_model, hook_out_fn, None, [folded_model.h[1].attn]):
        folded_out = folded_model(input_ids)

    with utils.HookManager(original_model, hook_out_fn, None, [original_model.h[1].attn]):
        original_out = original_model(input_ids)


GPT2SdpaAttention tensor([[1.6968969256e-02, -2.6908591390e-02, -2.0419303328e-03, -8.7223825976e-03],
        [6.9546517916e-03, 2.3677349091e-02, -2.9388369992e-02, 7.7289049514e-03],
        [-2.5100285187e-02, 3.4999854863e-02, -3.0305774882e-02, 2.9334910214e-03],
        [-4.2850956321e-02, 3.4401997924e-02, -2.1867046133e-02, -1.0952060111e-02]],
       grad_fn=<SliceBackward0>)
GPT2SdpaAttention tensor([[1.5742979944e-02, -2.1947111934e-02, -3.1891353428e-03, -9.2024542391e-03],
        [5.3768004291e-03, 2.7812674642e-02, -3.0345985666e-02, 7.9994052649e-03],
        [-2.6222892106e-02, 3.9662204683e-02, -3.0745055526e-02, 3.7268549204e-03],
        [-4.3906118721e-02, 3.9651136845e-02, -2.1979814395e-02, -1.0125691071e-02]],
       grad_fn=<SliceBackward0>)
