In [2]:
%load_ext autoreload
%autoreload 2

import os, sys
import torch
from transformers import AutoModel, AutoConfig
module_path = (os.path.join('../src'))
sys.path.insert(0, module_path)
from shrank import auto_fact


2024-11-25 14:39:03.576335: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
def count_param(module, trainable=False):
    if trainable:
        return sum(p.numel() for p in module.parameters() if p.requires_grad)
    else:
        return sum(p.numel() for p in module.parameters())

# Init Model

In [4]:
config = AutoConfig.from_pretrained('EleutherAI/pythia-160m')
# model = AutoModel(config=config)
model = AutoModel.from_pretrained("EleutherAI/pythia-160m")

In [5]:
print(model)
print(f'Params (total): {count_param(model)}')
print(f'Params (trainable): {count_param(model, trainable=True)}')

GPTNeoXModel(
  (embed_in): Embedding(50304, 768)
  (emb_dropout): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    (0-11): 12 x GPTNeoXLayer(
      (input_layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (post_attention_layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (post_attention_dropout): Dropout(p=0.0, inplace=False)
      (post_mlp_dropout): Dropout(p=0.0, inplace=False)
      (attention): GPTNeoXSdpaAttention(
        (rotary_emb): GPTNeoXRotaryEmbedding()
        (query_key_value): Linear(in_features=768, out_features=2304, bias=True)
        (dense): Linear(in_features=768, out_features=768, bias=True)
        (attention_dropout): Dropout(p=0.0, inplace=False)
      )
      (mlp): GPTNeoXMLP(
        (dense_h_to_4h): Linear(in_features=768, out_features=3072, bias=True)
        (dense_4h_to_h): Linear(in_features=3072, out_features=768, bias=True)
        (act): GELUActivation()
      )
    )
  )
  (final_layer_norm): La

# Factorize Model

### Apply absolute rank

In [6]:
%%time
fact_model = auto_fact(model, rank=64, deepcopy=True)

print(fact_model)
print(f'Params (total): {count_param(fact_model)}')
print(f'Params (trainable): {count_param(fact_model, trainable=True)}')

GPTNeoXModel(
  (embed_in): Embedding(50304, 768)
  (emb_dropout): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    (0-11): 12 x GPTNeoXLayer(
      (input_layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (post_attention_layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (post_attention_dropout): Dropout(p=0.0, inplace=False)
      (post_mlp_dropout): Dropout(p=0.0, inplace=False)
      (attention): GPTNeoXSdpaAttention(
        (rotary_emb): GPTNeoXRotaryEmbedding()
        (query_key_value): LED(
          (led_unit): Sequential(
            (0): Linear(in_features=768, out_features=64, bias=False)
            (1): Linear(in_features=64, out_features=2304, bias=True)
          )
        )
        (dense): LED(
          (led_unit): Sequential(
            (0): Linear(in_features=768, out_features=64, bias=False)
            (1): Linear(in_features=64, out_features=768, bias=True)
          )
        )
        (attention_dropout): D

### Apply percentage rank

In [9]:
%%time
pct_fact_model = auto_fact(model, rank=0.25, deepcopy=True)
print(pct_fact_model)
print(f'Params (total): {count_param(pct_fact_model)}')
print(f'Params (trainable): {count_param(pct_fact_model, trainable=True)}')

GPTNeoXModel(
  (embed_in): Embedding(50304, 768)
  (emb_dropout): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    (0-11): 12 x GPTNeoXLayer(
      (input_layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (post_attention_layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (post_attention_dropout): Dropout(p=0.0, inplace=False)
      (post_mlp_dropout): Dropout(p=0.0, inplace=False)
      (attention): GPTNeoXSdpaAttention(
        (rotary_emb): GPTNeoXRotaryEmbedding()
        (query_key_value): LED(
          (led_unit): Sequential(
            (0): Linear(in_features=768, out_features=144, bias=False)
            (1): Linear(in_features=144, out_features=2304, bias=True)
          )
        )
        (dense): LED(
          (led_unit): Sequential(
            (0): Linear(in_features=768, out_features=96, bias=False)
            (1): Linear(in_features=96, out_features=768, bias=True)
          )
        )
        (attention_dropout):

### Apply factorization only on specific modules

In [11]:
# Only factorize last 6 transformer layers and the pooler layer of the model
factorizable_submodules = list(model.transformer.h[6:])

In [None]:
%%time
fact_model = auto_fact(model, rank=0.2, deepcopy=True, submodules=factorizable_submodules)
print(f'Params (total): {count_param(model)}')
print(f'Params (trainable): {count_param(model, trainable=True)}')

CPU times: user 568 ms, sys: 36 ms, total: 604 ms
Wall time: 87.7 ms


90414336

In [None]:
%%time
fact_model = auto_fact(model, rank=0.2, deepcopy=True, submodules=factorizable_submodules)
count_param(fact_model)

CPU times: user 19 s, sys: 628 ms, total: 19.6 s
Wall time: 2.68 s


90414336

# Speed test on CPU

### Test Inference CPU

In [15]:
%%timeit
with torch.no_grad():
    y = model(torch.zeros(32,256, dtype=torch.long))

2.85 s ± 548 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [16]:
%%timeit
with torch.no_grad():
    y = fact_model(torch.zeros(32,256, dtype=torch.long))

2.42 s ± 423 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Test Forward-Backward CPU

In [17]:
%%timeit
y = model(torch.zeros(8,256, dtype=torch.long))
y.logits.sum().backward()

1.83 s ± 223 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [18]:
%%timeit
y = fact_model(torch.zeros(8,256, dtype=torch.long))
y.logits.sum().backward()

1.49 s ± 10 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Speed test on GPU

### Move models to GPU

In [26]:
model = model.cuda()
fact_model = fact_model.cuda()

### Test Inference GPU

In [27]:
x = torch.zeros(16,256, dtype=torch.long).cuda()

In [28]:
%%timeit
with torch.no_grad():
    y = model(x)

The slowest run took 8.52 times longer than the fastest. This could mean that an intermediate result is being cached.
65.7 ms ± 44.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [32]:
%%timeit
with torch.no_grad():
    y = fact_model(x)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (4096x768 and 2304x64)

### Test Forward-Backward GPU

In [23]:
x = torch.zeros(8,256, dtype=torch.long).cuda()

In [24]:
%%timeit
y = model(x)
y.logits.sum().backward()

275 ms ± 4.97 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [25]:
%%timeit
y = fact_model(x)
y.logits.sum().backward()

238 ms ± 2.56 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
