In [1]:
import torch
from transformers import GPT2Config, GPT2Model, GPT2Tokenizer

import random
import numpy as np

random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)

config = GPT2Config()
my_ln_model = GPT2Model(config).cuda()
my_so_ln_model = GPT2Model(config).cuda()

my_so_ln_model.load_state_dict(my_ln_model.state_dict())
my_ln_model.eval()
my_so_ln_model.eval()

import utils
counter = utils.Counter()


In [2]:
hook_pre_fn, hook_fn = utils.create_analyse_hook_fns(counter)

input_ids = torch.randint(0, 1000, (1, 128)).cuda()
my_input_ids = utils.MetadataTensor(input_ids, centered=False).cuda()

with utils.HookManager(my_so_ln_model, hook_fn, hook_pre_fn):
    my_so_ln_model(my_input_ids)

print('LayerNorm:', counter.ln_cnt)
print('Foldable:', counter.foldable_cnt)
print('Center modules:', counter.center_modules)

 <  GPT2Model >
   wte : Embedding
   wpe : Embedding
   drop : Dropout
   h : ModuleList
   ln_f : LayerNorm
   <- MetadataTensor False (1, 128) 0 set()
   <  Embedding >
     <- MetadataTensor False (1, 128) 0 set()
     -> MetadataTensor True (1, 128, 768) 1 {Embedding(50257, 768)}
   </ Embedding >
   <  Embedding >
     <- Tensor None (1, 128) 0 set()
     -> MetadataTensor True (1, 128, 768) 1 {Embedding(1024, 768)}
   </ Embedding >
   <  Dropout >
     <- MetadataTensor True (1, 128, 768) 2 {Embedding(50257, 768), Embedding(1024, 768)}
     -> MetadataTensor True (1, 128, 768) 2 {Embedding(50257, 768), Embedding(1024, 768)}
   </ Dropout >
   <  GPT2Block >
     ln_1 : LayerNorm
     attn : GPT2SdpaAttention
     ln_2 : LayerNorm
     mlp : GPT2MLP
     <- MetadataTensor True (1, 128, 768) 2 {Embedding(50257, 768), Embedding(1024, 768)}
     <  LayerNorm >
       <- MetadataTensor True (1, 128, 768) 2 {Embedding(50257, 768), Embedding(1024, 768)}
       -> MetadataTensor True (

  ret = func(*args, **kwargs)


In [3]:
import modules

for layer in counter.layernorms:
    modules.replace_layer_norm_forward(layer)

for layer in counter.center_modules:
    modules.center_modules(layer)


In [4]:
output_queue = []
check = utils.Check()
replace = True

def hook_original(module, input, output):
    name = module.__class__.__name__
    output_queue.append((output, name))

    # if isinstance(output, tuple):
    #     output = output[0]

    # with torch._tensor_str.printoptions(precision=10, sci_mode=True):
    #     len_shape = len(output.shape)
    #     index = tuple([0] * (len_shape - 2) + [slice(None, 4), slice(None, 4)])
    #     print(module.__class__.__name__, output[index])

def check_close_and_replace(tensor_a, tensor_b, check: utils.Check, tensor_a_str, tensor_b_str):
    check.hide_val()
    locals()[tensor_a_str] = tensor_a
    locals()[tensor_b_str] = tensor_b
    if check.check_eq(tensor_a_str, tensor_b_str, abs_tol=1e-2, local_vars=locals()):
        if replace and isinstance(tensor_a, torch.Tensor) and isinstance(tensor_b, torch.Tensor):
            tensor_b.data = tensor_a.data
    check.show_val()

def apply_func_to_nested_tuple_pair(t1, t2, func, *args, **kwargs):
    if isinstance(t1, tuple) and isinstance(t2, tuple):
        return tuple(apply_func_to_nested_tuple_pair(x1, x2, func, *args, **kwargs) for x1, x2 in zip(t1, t2))
    else:
        return func(t1, t2, *args, **kwargs)

def hook_folded(module, input, output):
    folded_name = module.__class__.__name__ + '_folded'
    original_output, original_name = output_queue.pop(0)
    original_name += '_original'
    apply_func_to_nested_tuple_pair(original_output, output, check_close_and_replace, check, original_name, folded_name)

    # if isinstance(output, tuple):
    #     output0 = output[0]

    # with torch._tensor_str.printoptions(precision=10, sci_mode=True):
    #     len_shape = len(output0.shape)
    #     index = tuple([0] * (len_shape - 2) + [slice(None, 4), slice(None, 4)])
    #     print(module.__class__.__name__, output0[index])



In [5]:
with utils.HookManager(my_ln_model, hook_original, None, list(my_ln_model.modules())[1:]):
    original_out = my_ln_model(input_ids)

with utils.HookManager(my_so_ln_model, hook_folded, None, list(my_so_ln_model.modules())[1:]):
    folded_out = my_so_ln_model(input_ids)


[1;34m# 0 [ Test ] Embedding_original ?= Embedding_folded[0m
[1;32m# 0 [ Pass ] Embedding_original == Embedding_folded[0m

[1;34m# 1 [ Test ] Embedding_original ?= Embedding_folded[0m
[1;32m# 1 [ Pass ] Embedding_original == Embedding_folded[0m

[1;34m# 2 [ Test ] Dropout_original ?= Dropout_folded[0m
[1;32m# 2 [ Pass ] Dropout_original == Dropout_folded[0m

[1;34m# 3 [ Test ] LayerNorm_original ?= SOLayerNorm_folded[0m
[1;33mMax diff: 0.13479137420654297
Location: [0, 37, 709]
LayerNorm_original: 2.488018274307251
SOLayerNorm_folded: 2.353226900100708[0m
[1;31m# 3 [ Fail ] LayerNorm_original != SOLayerNorm_folded[0m

[1;34m# 4 [ Test ] Conv1D_original ?= Conv1D_folded[0m
[1;33mMax diff: 0.2537236213684082
Location: [0, 37, 88]
Conv1D_original: -0.9486808776855469
Conv1D_folded: -1.202404499053955[0m
[1;31m# 4 [ Fail ] Conv1D_original != Conv1D_folded[0m

[1;34m# 5 [ Test ] Conv1D_original ?= Conv1D_folded[0m
[1;32m# 5 [ Pass ] Conv1D_original == Conv1D_folde

In [6]:
check.check_eq('folded_out[0]', 'original_out[0]', local_vars=locals(), abs_tol=1e-5)

[1;34m# 220 [ Test ] folded_out[0] ?= original_out[0][0m
[2m=== folded_out[0] ===[0m


tensor([[[1.3577e-01, -1.2654e-01, -3.7514e-01,  ..., 1.1333e-01, 6.4051e-01, -4.3048e-01],
         [5.6186e-01, 4.8840e-01, 7.8146e-01,  ..., -8.9503e-01, 8.5370e-01, -1.2255e+00],
         [7.8186e-01, 7.7026e-01, -2.3901e-01,  ..., 9.3298e-02, 1.0692e+00, -5.3862e-01],
         ...,
         [2.4021e-01, 6.9612e-01, -7.3467e-01,  ..., -2.8962e-01, -3.7554e-01, -2.4757e-01],
         [5.1597e-02, 6.8300e-01, -1.0035e+00,  ..., -1.8950e-01, 1.0808e+00, -5.9560e-01],
         [3.5050e-02, 2.6225e+00, -8.1549e-01,  ..., -1.8690e-01, 1.4773e+00, 7.5001e-01]]],
       device='cuda:0', grad_fn=<ViewBackward0>)
[2m=== original_out[0] ===[0m
tensor([[[2.1712e-02, -1.4150e-01, -3.4320e-01,  ..., 3.9669e-03, 6.5941e-01, -4.2921e-01],
         [5.1380e-01, 4.2810e-01, 7.8693e-01,  ..., -9.6909e-01, 8.7120e-01, -1.2768e+00],
         [7.8011e-01, 7.4773e-01, -2.6297e-01,  ..., 6.7078e-03, 1.0873e+00, -5.5771e-01],
         ...,
         [1.8941e-01, 7.3671e-01, -6.8908e-01,  ..., -2.5222e-01,

False

In [7]:
check.summary()

[1;34m==== < Summary > ====[0m
[1;32m# 0 [ Pass ][0m Embedding_original == Embedding_folded [2m(rel_tol=1e-05, abs_tol=0.01)[0m
[1;32m# 1 [ Pass ][0m Embedding_original == Embedding_folded [2m(rel_tol=1e-05, abs_tol=0.01)[0m
[1;32m# 2 [ Pass ][0m Dropout_original == Dropout_folded [2m(rel_tol=1e-05, abs_tol=0.01)[0m
[1;31m# 3 [ Fail ][0m LayerNorm_original != SOLayerNorm_folded [2m(rel_tol=1e-05, abs_tol=0.01)[0m
[1;31m# 4 [ Fail ][0m Conv1D_original != Conv1D_folded [2m(rel_tol=1e-05, abs_tol=0.01)[0m
[1;32m# 5 [ Pass ][0m Conv1D_original == Conv1D_folded [2m(rel_tol=1e-05, abs_tol=0.01)[0m
[1;32m# 6 [ Pass ][0m Dropout_original == Dropout_folded [2m(rel_tol=1e-05, abs_tol=0.01)[0m
[1;32m# 7 [ Pass ][0m GPT2SdpaAttention_original == GPT2SdpaAttention_folded [2m(rel_tol=1e-05, abs_tol=0.01)[0m
[1;31m# 8 [ Fail ][0m GPT2SdpaAttention_original != GPT2SdpaAttention_folded [2m(rel_tol=1e-05, abs_tol=0.01)[0m
[1;31m# 9 [ Fail ][0m GPT2SdpaAttention_or

In [9]:
native_ln_model = GPT2Model(config).cuda()
my_ln_model = GPT2Model(config).cuda()
native_so_ln_model = GPT2Model(config).cuda()
my_so_ln_model = GPT2Model(config).cuda()

my_ln_model.load_state_dict(native_ln_model.state_dict())
native_so_ln_model.load_state_dict(native_ln_model.state_dict())
my_so_ln_model.load_state_dict(native_ln_model.state_dict())

native_ln_model.eval()
my_ln_model.eval()
native_so_ln_model.eval()
my_so_ln_model.eval()

my_ln_counter = utils.Counter()
native_so_ln_counter = utils.Counter()
my_so_ln_counter = utils.Counter()

# native_ln_hook_pre_fn, native_ln_hook_fn = utils.create_analyse_hook_fns(my_ln_counter, _print=False)
my_ln_hook_pre_fn, my_ln_hook_fn = utils.create_analyse_hook_fns(my_ln_counter, _print=False)
native_so_ln_hook_pre_fn, native_so_ln_hook_fn = utils.create_analyse_hook_fns(my_so_ln_counter, _print=False)
my_so_ln_hook_pre_fn, my_so_ln_hook_fn = utils.create_analyse_hook_fns(my_so_ln_counter, _print=False)

with utils.HookManager(my_ln_model, my_ln_hook_fn, my_ln_hook_pre_fn):
    my_ln_model(my_input_ids)

with utils.HookManager(native_so_ln_model, native_so_ln_hook_pre_fn, native_so_ln_hook_fn):
    native_so_ln_model(my_input_ids)

with utils.HookManager(my_so_ln_model, my_so_ln_hook_fn, my_so_ln_hook_pre_fn):
    my_so_ln_model(my_input_ids)

for layer in my_ln_counter.center_modules:
    modules.center_modules(layer)

for layer in my_so_ln_counter.center_modules:
    modules.center_modules(layer)

for layer in my_so_ln_counter.center_modules:
    modules.center_modules(layer)

for layer in my_ln_counter.layernorms:
    modules.replace_layer_norm_forward(
        layer,
        forward_fn=modules.myln_forward,
        class_name='MyLayerNorm'
    )

for layer in native_so_ln_counter.layernorms:
    modules.replace_layer_norm_forward(
        layer,
        forward_fn=modules.native_soln_forward,
        class_name='NativeSOLayerNorm'
    )

for layer in my_so_ln_counter.layernorms:
    modules.replace_layer_norm_forward(
        layer,
        forward_fn=modules.soln_forward,
        class_name='MySOLayerNorm'
    )


TypeError: create_analyse_hook_fns.<locals>.hook_fn() missing 1 required positional argument: 'outputs'

In [None]:
from deepspeed.profiling.flops_profiler import get_model_profile


[2024-11-20 16:33:25,224] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cuda (auto detect)


W1120 16:33:27.226000 3840 torch\distributed\elastic\multiprocessing\redirects.py:28] NOTE: Redirects are currently not supported in Windows or MacOs.


In [None]:
flops, macs, params = get_model_profile(
    native_ln_model,
    kwargs={'input_ids': input_ids},
    print_profile=True,
    detailed=True,
)


[2024-11-20 16:33:27,340] [INFO] [profiler.py:1220:get_model_profile] Flops profiler warming-up...
[2024-11-20 16:33:27,455] [INFO] [profiler.py:81:start_profile] Flops profiler started

-------------------------- DeepSpeed Flops Profiler --------------------------
Profile Summary at step 1:
Notations:
data parallel size (dp_size), model parallel size(mp_size),
number of parameters (params), number of multiply-accumulate operations(MACs),
number of floating-point operations (flops), floating-point operations per second (FLOPS),
fwd latency (forward propagation latency), bwd latency (backward propagation latency),
step (weights update latency), iter latency (sum of fwd, bwd and step latency)

params per GPU:                                                         124.44 M
params of model = params per GPU * mp_size:                             0       
fwd MACs per GPU:                                                       11.17 GMACs
fwd flops per GPU:                                   

In [None]:
flops, macs, params = get_model_profile(
    native_so_ln_model,
    kwargs={'input_ids': input_ids},
    print_profile=True,
    detailed=True,
)


[2024-11-20 16:33:27,697] [INFO] [profiler.py:1220:get_model_profile] Flops profiler warming-up...
[2024-11-20 16:33:27,755] [INFO] [profiler.py:81:start_profile] Flops profiler started

-------------------------- DeepSpeed Flops Profiler --------------------------
Profile Summary at step 1:
Notations:
data parallel size (dp_size), model parallel size(mp_size),
number of parameters (params), number of multiply-accumulate operations(MACs),
number of floating-point operations (flops), floating-point operations per second (FLOPS),
fwd latency (forward propagation latency), bwd latency (backward propagation latency),
step (weights update latency), iter latency (sum of fwd, bwd and step latency)

params per GPU:                                                         124.44 M
params of model = params per GPU * mp_size:                             0       
fwd MACs per GPU:                                                       11.48 GMACs
fwd flops per GPU:                                   

In [None]:
flops, macs, params = get_model_profile(
    my_ln_model,
    kwargs={'input_ids': input_ids},
    print_profile=True,
    detailed=True,
)


[2024-11-20 16:33:27,946] [INFO] [profiler.py:1220:get_model_profile] Flops profiler warming-up...
[2024-11-20 16:33:28,037] [INFO] [profiler.py:81:start_profile] Flops profiler started

-------------------------- DeepSpeed Flops Profiler --------------------------
Profile Summary at step 1:
Notations:
data parallel size (dp_size), model parallel size(mp_size),
number of parameters (params), number of multiply-accumulate operations(MACs),
number of floating-point operations (flops), floating-point operations per second (FLOPS),
fwd latency (forward propagation latency), bwd latency (backward propagation latency),
step (weights update latency), iter latency (sum of fwd, bwd and step latency)

params per GPU:                                                         124.44 M
params of model = params per GPU * mp_size:                             0       
fwd MACs per GPU:                                                       11.78 GMACs
fwd flops per GPU:                                   

In [None]:
flops, macs, params = get_model_profile(
    my_so_ln_model,
    kwargs={'input_ids': input_ids},
    print_profile=True,
    detailed=True,
)


[2024-11-20 16:33:28,230] [INFO] [profiler.py:1220:get_model_profile] Flops profiler warming-up...
[2024-11-20 16:33:28,304] [INFO] [profiler.py:81:start_profile] Flops profiler started

-------------------------- DeepSpeed Flops Profiler --------------------------
Profile Summary at step 1:
Notations:
data parallel size (dp_size), model parallel size(mp_size),
number of parameters (params), number of multiply-accumulate operations(MACs),
number of floating-point operations (flops), floating-point operations per second (FLOPS),
fwd latency (forward propagation latency), bwd latency (backward propagation latency),
step (weights update latency), iter latency (sum of fwd, bwd and step latency)

params per GPU:                                                         124.44 M
params of model = params per GPU * mp_size:                             0       
fwd MACs per GPU:                                                       12.08 GMACs
fwd flops per GPU:                                   

In [None]:
from torch.profiler import profile, record_function, ProfilerActivity, schedule

my_schedule = schedule(
    wait=100,
    warmup=50,
    active=100,
)

torch.cuda.empty_cache()

my_so_ln_model.eval()
my_ln_model.eval()

with torch.no_grad():
    with profile(
        activities=[
            ProfilerActivity.CPU, ProfilerActivity.CUDA
        ],
        schedule=my_schedule
    ) as prof:
        with record_function("my_so_ln_model_inference"):
            for _ in range(1000):
                my_so_ln_model(input_ids)
                prof.step()
    print(prof.key_averages().table())
    # prof.export_chrome_trace("tmp/folded_trace.json")

    with profile(
        activities=[
            ProfilerActivity.CPU, ProfilerActivity.CUDA
        ],
        schedule=my_schedule
    ) as prof:
        with record_function("my_ln_model_inference"):
            for _ in range(1000):
                my_ln_model(input_ids)
                prof.step()
    print(prof.key_averages().table())
    # prof.export_chrome_trace("tmp/original_trace.json")


-------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                    ProfilerStep*        50.45%        5.614s       100.00%       11.128s      27.819ms        2.724s        24.65%       11.050s      27.624ms           400  
                                       aten::view         1.46%     162.228ms         1.46%     162.228ms       2.704us     284.151ms         2.57%     284.151ms       4.736us         59999  
                                     at