In [48]:
%load_ext autoreload
%autoreload 2

import os, sys
import torch
from transformers import AutoModel, AutoConfig
module_path = (os.path.join('../src'))
sys.path.insert(0, module_path)
from shrank import auto_fact


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [49]:
def count_param(module, trainable=False):
    if trainable:
        return sum(p.numel() for p in module.parameters() if p.requires_grad)
    else:
        return sum(p.numel() for p in module.parameters())

# Init Model

In [50]:
# config = AutoConfig.from_pretrained('EleutherAI/pythia-410m')
config = AutoConfig.from_pretrained("state-spaces/mamba-370m-hf")

# model = AutoModel(config=config)
# model = AutoModel.from_pretrained("EleutherAI/pythia-410m")
model = AutoModel.from_pretrained("state-spaces/mamba-370m-hf")

In [51]:
print(model)
print(f'Params (total): {count_param(model)}')
print(f'Params (trainable): {count_param(model, trainable=True)}')

MambaModel(
  (embeddings): Embedding(50280, 1024)
  (layers): ModuleList(
    (0-47): 48 x MambaBlock(
      (norm): MambaRMSNorm(1024, eps=1e-05)
      (mixer): MambaMixer(
        (conv1d): Conv1d(2048, 2048, kernel_size=(4,), stride=(1,), padding=(3,), groups=2048)
        (act): SiLU()
        (in_proj): Linear(in_features=1024, out_features=4096, bias=False)
        (x_proj): Linear(in_features=2048, out_features=96, bias=False)
        (dt_proj): Linear(in_features=64, out_features=2048, bias=True)
        (out_proj): Linear(in_features=2048, out_features=1024, bias=False)
      )
    )
  )
  (norm_f): MambaRMSNorm(1024, eps=1e-05)
)
Params (total): 371516416
Params (trainable): 371516416


# Factorize Model

### Apply absolute rank

In [71]:
%%time
fact_model = auto_fact(model, rank=64, groups_out=True, skip_high_groups=True, deepcopy=True)

print(fact_model)
print(f'Params (total): {count_param(fact_model)}')
print(f'Params (trainable): {count_param(fact_model, trainable=True)}')

  f"skipping {type(module)} with groups: {module.groups}, rank: {rank}"
  (module.in_features, module.out_features)


MambaModel(
  (embeddings): Embedding(50280, 1024)
  (layers): ModuleList(
    (0-47): 48 x MambaBlock(
      (norm): MambaRMSNorm(1024, eps=1e-05)
      (mixer): MambaMixer(
        (conv1d): Conv1d(2048, 2048, kernel_size=(4,), stride=(1,), padding=(3,), groups=2048)
        (act): SiLU()
        (in_proj): LED(
          (led_unit): Sequential(
            (0): Linear(in_features=1024, out_features=64, bias=False)
            (1): Linear(in_features=64, out_features=4096, bias=False)
          )
        )
        (x_proj): LED(
          (led_unit): Sequential(
            (0): Linear(in_features=2048, out_features=64, bias=False)
            (1): Linear(in_features=64, out_features=96, bias=False)
          )
        )
        (dt_proj): Linear(in_features=64, out_features=2048, bias=True)
        (out_proj): LED(
          (led_unit): Sequential(
            (0): Linear(in_features=2048, out_features=64, bias=False)
            (1): Linear(in_features=64, out_features=1024, bias=F

### Apply percentage rank

In [2]:
%%time
pct_fact_model = auto_fact(model, rank=0.25, deepcopy=True)
print(pct_fact_model)
print(f'Params (total): {count_param(pct_fact_model)}')
print(f'Params (trainable): {count_param(pct_fact_model, trainable=True)}')

NameError: name 'auto_fact' is not defined

### Apply factorization only on specific modules

In [11]:
# Only factorize last 6 transformer layers and the pooler layer of the model
factorizable_submodules = list(model.transformer.h[6:])

In [None]:
%%time
fact_model = auto_fact(model, rank=0.2, deepcopy=True, submodules=factorizable_submodules)
print(f'Params (total): {count_param(model)}')
print(f'Params (trainable): {count_param(model, trainable=True)}')

CPU times: user 568 ms, sys: 36 ms, total: 604 ms
Wall time: 87.7 ms


90414336

In [None]:
%%time
fact_model = auto_fact(model, rank=0.2, deepcopy=True, submodules=factorizable_submodules)
count_param(fact_model)

CPU times: user 19 s, sys: 628 ms, total: 19.6 s
Wall time: 2.68 s


90414336

# Speed test on CPU

### Test Inference CPU

In [15]:
%%timeit
with torch.no_grad():
    y = model(torch.zeros(32,256, dtype=torch.long))

2.85 s ± 548 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [16]:
%%timeit
with torch.no_grad():
    y = fact_model(torch.zeros(32,256, dtype=torch.long))

2.42 s ± 423 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Test Forward-Backward CPU

In [17]:
%%timeit
y = model(torch.zeros(8,256, dtype=torch.long))
y.logits.sum().backward()

1.83 s ± 223 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [18]:
%%timeit
y = fact_model(torch.zeros(8,256, dtype=torch.long))
y.logits.sum().backward()

1.49 s ± 10 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Speed test on GPU

### Move models to GPU

In [26]:
model = model.cuda()
fact_model = fact_model.cuda()

### Test Inference GPU

In [27]:
x = torch.zeros(16,256, dtype=torch.long).cuda()

In [28]:
%%timeit
with torch.no_grad():
    y = model(x)

The slowest run took 8.52 times longer than the fastest. This could mean that an intermediate result is being cached.
65.7 ms ± 44.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [32]:
%%timeit
with torch.no_grad():
    y = fact_model(x)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (4096x768 and 2304x64)

### Test Forward-Backward GPU

In [23]:
x = torch.zeros(8,256, dtype=torch.long).cuda()

In [24]:
%%timeit
y = model(x)
y.logits.sum().backward()

275 ms ± 4.97 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [25]:
%%timeit
y = fact_model(x)
y.logits.sum().backward()

238 ms ± 2.56 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
