In [1]:
import os, sys
import torch
from transformers import GPT2LMHeadModel, GPT2Config
from greenformer import auto_fact
from itertools import chain

from os import path
import sys

In [2]:
def count_param(module, trainable=False):
    """Count the scalar parameters held by ``module``.

    Args:
        module: Object exposing ``parameters()`` that yields tensors
            (each must provide ``numel()`` and ``requires_grad``).
        trainable: When truthy, only parameters with ``requires_grad`` set
            are counted; otherwise every parameter is counted.

    Returns:
        int: Sum of ``numel()`` over the selected parameters (0 if none).
    """
    total = 0
    for param in module.parameters():
        # Skip frozen parameters only when the caller asked for trainable ones.
        if trainable and not param.requires_grad:
            continue
        total += param.numel()
    return total

# Init Model

In [3]:
# Load the pretrained GPT-2 ('gpt2') checkpoint.
# The configuration is fetched explicitly so it remains available for
# inspection, and it is handed to from_pretrained so the model is built exactly
# once. Constructing GPT2LMHeadModel(config) first would allocate and randomly
# initialise a full model only to discard it on the next line.
config = GPT2Config.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2', config=config)

In [4]:
# Total parameter count of the unfactorized model; serves as the baseline for
# the factorized counts reported in the cells below.
count_param(model)

124439808

# Factorize Model

### Apply absolute rank

In [5]:
%%time
# Factorize `model` with an absolute rank of 256 using the 'random' solver and
# display the parameter count of the result (%%time reports the cell's cost).
# NOTE(review): deepcopy=True presumably leaves `model` itself unmodified, since
# later cells factorize `model` again — confirm against the greenformer docs.
fact_model = auto_fact(model, rank=256, deepcopy=True, solver='random', num_iter=20)
count_param(fact_model)

CPU times: user 708 ms, sys: 220 ms, total: 928 ms
Wall time: 229 ms


77253888

In [6]:
%%time
# Same absolute-rank (256) factorization, this time with the 'svd' solver.
# `fact_model` is overwritten; only the displayed parameter count and the
# timing are kept from this cell.
fact_model = auto_fact(model, rank=256, deepcopy=True, solver='svd', num_iter=20)
count_param(fact_model)

CPU times: user 2min 7s, sys: 4.74 s, total: 2min 12s
Wall time: 14.4 s


77253888

In [7]:
%%time
# Same absolute-rank (256) factorization, this time with the 'snmf' solver.
# `fact_model` is overwritten; the last expression displays its parameter count.
fact_model = auto_fact(model, rank=256, deepcopy=True, solver='snmf', num_iter=20)
count_param(fact_model)

CPU times: user 1min 9s, sys: 4.49 s, total: 1min 14s
Wall time: 9.17 s


77253888

### Apply percentage rank

In [8]:
%%time
# Factorize `model` with a fractional rank (0.4) using the 'random' solver and
# display the resulting parameter count.
# NOTE(review): a float rank is presumably interpreted as a ratio of each
# layer's maximum rank (per the section heading) — confirm in greenformer docs.
fact_model = auto_fact(model, rank=0.4, deepcopy=True, solver='random', num_iter=20)
count_param(fact_model)

CPU times: user 648 ms, sys: 8 ms, total: 656 ms
Wall time: 183 ms


73383168

In [9]:
%%time
# Fractional-rank (0.4) factorization with the 'svd' solver.
# `fact_model` is overwritten; the last expression displays its parameter count.
fact_model = auto_fact(model, rank=0.4, deepcopy=True, solver='svd', num_iter=20)
count_param(fact_model)

CPU times: user 1min 48s, sys: 3.65 s, total: 1min 52s
Wall time: 11.7 s


73383168

In [10]:
%%time
# Fractional-rank (0.4) factorization with the 'snmf' solver.
# `fact_model` is overwritten; the last expression displays its parameter count.
fact_model = auto_fact(model, rank=0.4, deepcopy=True, solver='snmf', num_iter=20)
count_param(fact_model)

CPU times: user 1min, sys: 2.94 s, total: 1min 3s
Wall time: 7.56 s


73383168

### Apply factorization only on specific modules

In [11]:
# Restrict factorization to the transformer blocks from index 6 onward
# (model.transformer.h[6:]). No other module is selected here; in particular
# there is no pooler layer in this list.
factorizable_submodules = list(model.transformer.h[6:])

In [12]:
%%time
# Factorize with a fractional rank (0.2) and the 'random' solver, passing
# `factorizable_submodules` so that only those modules are targeted, then
# display the resulting parameter count.
# NOTE(review): the list holds modules of `model` while deepcopy=True is set —
# presumably greenformer maps them onto the copy; confirm in its docs.
fact_model = auto_fact(model, rank=0.2, deepcopy=True, solver='random', num_iter=20, submodules=factorizable_submodules)
count_param(fact_model)

CPU times: user 568 ms, sys: 36 ms, total: 604 ms
Wall time: 87.7 ms


90414336

In [13]:
%%time
# Submodule-restricted factorization (rank 0.2) with the 'svd' solver.
# `fact_model` is overwritten; the last expression displays its parameter count.
fact_model = auto_fact(model, rank=0.2, deepcopy=True, solver='svd', num_iter=20, submodules=factorizable_submodules)
count_param(fact_model)

CPU times: user 24.2 s, sys: 948 ms, total: 25.1 s
Wall time: 2.78 s


90414336

In [14]:
%%time
# Submodule-restricted factorization (rank 0.2) with the 'snmf' solver.
# This is the last assignment to `fact_model` before the speed tests, so the
# timing cells that follow measure this particular factorized model.
fact_model = auto_fact(model, rank=0.2, deepcopy=True, solver='snmf', num_iter=20, submodules=factorizable_submodules)
count_param(fact_model)

CPU times: user 19 s, sys: 628 ms, total: 19.6 s
Wall time: 2.68 s


90414336

# Speed test on CPU

### Test Inference CPU

In [15]:
%%timeit
# Time a gradient-free forward pass of the original model on an all-zero
# token-id batch of shape (32, 256). The input tensor is created inside the
# timed region.
with torch.no_grad():
    y = model(torch.zeros(32,256, dtype=torch.long))

2.85 s ± 548 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [16]:
%%timeit
# Time a gradient-free forward pass of the factorized model on an all-zero
# token-id batch of shape (32, 256), for comparison with the original model.
with torch.no_grad():
    y = fact_model(torch.zeros(32,256, dtype=torch.long))

2.42 s ± 423 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Test Forward-Backward CPU

In [17]:
%%timeit
# Time one forward + backward pass of the original model on an all-zero
# token-id batch of shape (8, 256); the summed logits act as a dummy loss.
# NOTE(review): gradients are never zeroed here, so .grad accumulates across
# timeit repetitions; this does not affect the amount of work per iteration.
y = model(torch.zeros(8,256, dtype=torch.long))
y.logits.sum().backward()

1.83 s ± 223 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [18]:
%%timeit
# Time one forward + backward pass of the factorized model on an all-zero
# token-id batch of shape (8, 256); the summed logits act as a dummy loss.
# NOTE(review): gradients are never zeroed here, so .grad accumulates across
# timeit repetitions; this does not affect the amount of work per iteration.
y = fact_model(torch.zeros(8,256, dtype=torch.long))
y.logits.sum().backward()

1.49 s ± 10 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Speed test on GPU

### Move models to GPU

In [19]:
# Move both the original and the factorized model to the GPU so the
# following timing cells run on CUDA. Requires a CUDA-capable device.
model = model.cuda()
fact_model = fact_model.cuda()

### Test Inference GPU

In [20]:
# All-zero token-id batch of shape (16, 256) placed on the GPU; created once
# here so the GPU inference timing cells do not include input construction.
x = torch.zeros(16,256, dtype=torch.long).cuda()

In [21]:
%%timeit
with torch.no_grad():
    y = model(x)

180 ms ± 412 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [22]:
%%timeit
with torch.no_grad():
    y = fact_model(x)

155 ms ± 273 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


### Test Forward-Backward GPU

In [23]:
# Rebind `x` to a smaller all-zero token-id batch of shape (8, 256) on the GPU
# for the forward-backward timing cells (the inference cells used (16, 256)).
x = torch.zeros(8,256, dtype=torch.long).cuda()

In [24]:
%%timeit
y = model(x)
y.logits.sum().backward()

275 ms ± 4.97 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [25]:
%%timeit
y = fact_model(x)
y.logits.sum().backward()

238 ms ± 2.56 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
