In [1]:
import torch
from transformers import GPT2Config, GPT2Model

import random
import numpy as np

random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)

config = GPT2Config()
original_model = GPT2Model(config).cuda()
folded_model = GPT2Model(config).cuda()

folded_model.load_state_dict(original_model.state_dict())
original_model.eval()
folded_model.eval()

import utils
counter = utils.Counter()


In [2]:
hook_pre_fn, hook_fn = utils.create_analyse_hook_fns(counter)

input_ids = torch.randint(0, 1000, (1, 128)).cuda()
my_input_ids = utils.MetadataTensor(input_ids, centered=False).cuda()

with utils.HookManager(folded_model, hook_fn, hook_pre_fn):
    folded_model(my_input_ids)

print('LayerNorm:', counter.ln_cnt)
print('Foldable:', counter.foldable_cnt)
print('Center modules:', counter.center_modules)

 <  GPT2Model >
   wte : Embedding
   wpe : Embedding
   drop : Dropout
   h : ModuleList
   ln_f : LayerNorm
   <- MetadataTensor False (1, 128) 0 set()
   <  Embedding >
     <- MetadataTensor False (1, 128) 0 set()
     -> MetadataTensor True (1, 128, 768) 1 {Embedding(50257, 768)}
   </ Embedding >
   <  Embedding >
     <- Tensor None (1, 128) 0 set()
     -> MetadataTensor True (1, 128, 768) 1 {Embedding(1024, 768)}
   </ Embedding >
   <  Dropout >
     <- MetadataTensor True (1, 128, 768) 2 {Embedding(50257, 768), Embedding(1024, 768)}
     -> MetadataTensor True (1, 128, 768) 2 {Embedding(50257, 768), Embedding(1024, 768)}
   </ Dropout >
   <  GPT2Block >
     ln_1 : LayerNorm
     attn : GPT2SdpaAttention
     ln_2 : LayerNorm
     mlp : GPT2MLP
     <- MetadataTensor True (1, 128, 768) 2 {Embedding(50257, 768), Embedding(1024, 768)}
     <  LayerNorm >
       <- MetadataTensor True (1, 128, 768) 2 {Embedding(50257, 768), Embedding(1024, 768)}
       -> MetadataTensor True (

  ret = func(*args, **kwargs)


In [3]:
import modules

for layer in counter.layernorms:
    modules.replace_layer_norm_forward(layer)

for layer in counter.center_modules:
    modules.center_modules(layer)


In [4]:
output_queue = []
check = utils.Check()
replace = False

def hook_original(module, input, output):
    name = module.__class__.__name__
    output_queue.append((output, name))

    # if isinstance(output, tuple):
    #     output = output[0]

    # with torch._tensor_str.printoptions(precision=10, sci_mode=True):
    #     len_shape = len(output.shape)
    #     index = tuple([0] * (len_shape - 2) + [slice(None, 4), slice(None, 4)])
    #     print(module.__class__.__name__, output[index])

def check_close_and_replace(tensor_a, tensor_b, check: utils.Check, tensor_a_str, tensor_b_str):
    check.hide_val()
    locals()[tensor_a_str] = tensor_a
    locals()[tensor_b_str] = tensor_b
    if check.check_eq(tensor_a_str, tensor_b_str, abs_tol=1e-5, local_vars=locals()):
        if replace and isinstance(tensor_a, torch.Tensor) and isinstance(tensor_b, torch.Tensor):
            tensor_b.data = tensor_a.data
    check.show_val()

def apply_func_to_nested_tuple_pair(t1, t2, func, *args, **kwargs):
    if isinstance(t1, tuple) and isinstance(t2, tuple):
        return tuple(apply_func_to_nested_tuple_pair(x1, x2, func, *args, **kwargs) for x1, x2 in zip(t1, t2))
    else:
        return func(t1, t2, *args, **kwargs)

def hook_folded(module, input, output):
    folded_name = module.__class__.__name__ + '_folded'
    original_output, original_name = output_queue.pop(0)
    original_name += '_original'
    apply_func_to_nested_tuple_pair(original_output, output, check_close_and_replace, check, original_name, folded_name)

    # if isinstance(output, tuple):
    #     output0 = output[0]

    # with torch._tensor_str.printoptions(precision=10, sci_mode=True):
    #     len_shape = len(output0.shape)
    #     index = tuple([0] * (len_shape - 2) + [slice(None, 4), slice(None, 4)])
    #     print(module.__class__.__name__, output0[index])



In [5]:
with utils.HookManager(original_model, hook_original, None, list(original_model.modules())[1:]):
    original_out = original_model(input_ids)

with utils.HookManager(folded_model, hook_folded, None, list(folded_model.modules())[1:]):
    folded_out = folded_model(input_ids)


[1;34m# 0 [ Test ] Embedding_original ?= Embedding_folded[0m
[1;33mMean abs diff: 0.0005862714606337249[0m
[1;31m# 0 [ Fail ] Embedding_original != Embedding_folded[0m

[1;34m# 1 [ Test ] Embedding_original ?= Embedding_folded[0m
[1;33mMean abs diff: 0.0006422363221645355[0m
[1;31m# 1 [ Fail ] Embedding_original != Embedding_folded[0m

[1;34m# 2 [ Test ] Dropout_original ?= Dropout_folded[0m
[1;33mMean abs diff: 0.0008070007897913456[0m
[1;31m# 2 [ Fail ] Dropout_original != Dropout_folded[0m

[1;34m# 3 [ Test ] LayerNorm_original ?= SOLayerNorm_folded[0m
[1;32m# 3 [ Pass ] LayerNorm_original == SOLayerNorm_folded[0m

[1;34m# 4 [ Test ] Conv1D_original ?= Conv1D_folded[0m
[1;32m# 4 [ Pass ] Conv1D_original == Conv1D_folded[0m

[1;34m# 5 [ Test ] Conv1D_original ?= Conv1D_folded[0m
[1;33mMean abs diff: 0.0002815908519551158[0m
[1;31m# 5 [ Fail ] Conv1D_original != Conv1D_folded[0m

[1;34m# 6 [ Test ] Dropout_original ?= Dropout_folded[0m
[1;33mMean abs

In [6]:
check.check_eq('folded_out[0]', 'original_out[0]', local_vars=locals(), abs_tol=1e-5)

[1;34m# 220 [ Test ] folded_out[0] ?= original_out[0][0m
[2m=== folded_out[0] ===[0m
tensor([[[2.1711e-02, -1.4150e-01, -3.4320e-01,  ..., 3.9684e-03, 6.5941e-01, -4.2921e-01],
         [5.1380e-01, 4.2810e-01, 7.8693e-01,  ..., -9.6909e-01, 8.7120e-01, -1.2768e+00],
         [7.8011e-01, 7.4773e-01, -2.6297e-01,  ..., 6.7087e-03, 1.0873e+00, -5.5771e-01],
         ...,
         [1.8941e-01, 7.3671e-01, -6.8908e-01,  ..., -2.5222e-01, -3.3449e-01, -2.5983e-01],
         [3.0729e-02, 6.8014e-01, -1.0200e+00,  ..., -1.9081e-01, 1.1007e+00, -6.1009e-01],
         [3.8300e-02, 2.6348e+00, -8.0384e-01,  ..., -1.8258e-01, 1.4888e+00, 7.4053e-01]]],
       device='cuda:0', grad_fn=<ViewBackward0>)
[2m=== original_out[0] ===[0m
tensor([[[2.1712e-02, -1.4150e-01, -3.4320e-01,  ..., 3.9669e-03, 6.5941e-01, -4.2921e-01],
         [5.1380e-01, 4.2810e-01, 7.8693e-01,  ..., -9.6909e-01, 8.7120e-01, -1.2768e+00],
         [7.8011e-01, 7.4773e-01, -2.6297e-01,  ..., 6.7078e-03, 1.0873e+00, -5.5

True

In [7]:
check.summary()

[1;34m==== < Summary > ====[0m
[1;31m# 0 [ Fail ][0m Embedding_original != Embedding_folded [2m(rel_tol=1e-05, abs_tol=1e-05)[0m
[1;31m# 1 [ Fail ][0m Embedding_original != Embedding_folded [2m(rel_tol=1e-05, abs_tol=1e-05)[0m
[1;31m# 2 [ Fail ][0m Dropout_original != Dropout_folded [2m(rel_tol=1e-05, abs_tol=1e-05)[0m
[1;32m# 3 [ Pass ][0m LayerNorm_original == SOLayerNorm_folded [2m(rel_tol=1e-05, abs_tol=1e-05)[0m
[1;32m# 4 [ Pass ][0m Conv1D_original == Conv1D_folded [2m(rel_tol=1e-05, abs_tol=1e-05)[0m
[1;31m# 5 [ Fail ][0m Conv1D_original != Conv1D_folded [2m(rel_tol=1e-05, abs_tol=1e-05)[0m
[1;31m# 6 [ Fail ][0m Dropout_original != Dropout_folded [2m(rel_tol=1e-05, abs_tol=1e-05)[0m
[1;31m# 7 [ Fail ][0m GPT2SdpaAttention_original != GPT2SdpaAttention_folded [2m(rel_tol=1e-05, abs_tol=1e-05)[0m
[1;32m# 8 [ Pass ][0m GPT2SdpaAttention_original == GPT2SdpaAttention_folded [2m(rel_tol=1e-05, abs_tol=1e-05)[0m
[1;32m# 9 [ Pass ][0m GPT2SdpaAtt

In [8]:
original_model = GPT2Model(config).cuda()
folded_model = GPT2Model(config).cuda()

folded_model.load_state_dict(original_model.state_dict())
original_model.eval()
folded_model.eval()

hook_pre_fn, hook_fn = utils.create_analyse_hook_fns(counter, _print=False)

original_counter = utils.Counter()
folded_counter = utils.Counter()

with utils.HookManager(original_model, hook_fn, hook_pre_fn):
    original_model(my_input_ids)

with utils.HookManager(folded_model, hook_fn, hook_pre_fn):
    folded_model(my_input_ids)

for layer in original_counter.layernorms:
    modules.replace_layer_norm_forward(layer, forward_fn=modules.myln_forward)

for layer in folded_counter.layernorms:
    modules.replace_layer_norm_forward(layer, forward_fn=modules.soln_forward)

for layer in folded_counter.center_modules:
    modules.center_modules(layer)


In [13]:
from torch.profiler import profile, record_function, ProfilerActivity

with torch.no_grad():
    with profile(activities=[
        ProfilerActivity.CPU, ProfilerActivity.CUDA
    ]) as prof:
        with record_function("original_model_inference"):
            for _ in range(100):
                original_model(input_ids)
    print(prof.key_averages().table())
    prof.export_chrome_trace("original_trace.json")

with torch.no_grad():
    with profile(activities=[
        ProfilerActivity.CPU, ProfilerActivity.CUDA
    ]) as prof:
        with record_function("folded_model_inference"):
            for _ in range(100):
                folded_model(input_ids)
    print(prof.key_averages().table())
    prof.export_chrome_trace("folded_trace.json")


-------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                         original_model_inference        46.18%        3.649s       100.00%        7.900s        7.900s     450.348ms         5.68%        7.933s        7.933s             1  
                                       aten::view         1.16%      91.986ms         1.16%      91.986ms       4.599us     174.355ms         2.20%     174.355ms       8.718us         20000  
                                     at