In [3]:
import os, sys
import torch
from transformers import BertModel, BertConfig
from greenformer import auto_fact
from itertools import chain

from os import path
import sys

In [None]:
def count_param(module, trainable=False):
    """Return the number of scalar parameters held by ``module``.

    Args:
        module: a ``torch.nn.Module`` whose ``parameters()`` are counted.
        trainable: when True, only parameters with ``requires_grad`` set are
            included; otherwise every parameter is counted.

    Returns:
        int: total element count over the selected parameter tensors
        (0 for a module without parameters).
    """
    total = 0
    for param in module.parameters():
        # Skip frozen tensors only when the caller asked for trainable ones.
        if trainable and not param.requires_grad:
            continue
        total += param.numel()
    return total

# Init Model

In [3]:
# Load the pretrained bert-base-uncased weights directly. The previous version
# first built a randomly initialised BertModel(config) and then immediately
# overwrote it with from_pretrained, wasting time and memory; the config is
# now passed straight to from_pretrained instead.
config = BertConfig.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased', config=config)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
# Baseline: total parameter count of the unfactorized pretrained model.
count_param(module=model)

109482240

# Factorize Model

### Apply absolute rank

In [5]:
%%time
fact_model = auto_fact(model, rank=256, deepcopy=True, solver='random', num_iter=20)
count_param(fact_model)

CPU times: user 947 ms, sys: 512 ms, total: 1.46 s
Wall time: 401 ms


66818304

In [6]:
%%time
fact_model = auto_fact(model, rank=256, deepcopy=True, solver='svd', num_iter=20)
count_param(fact_model)

CPU times: user 5min 57s, sys: 7.44 s, total: 6min 4s
Wall time: 23.7 s


66818304

In [7]:
%%time
fact_model = auto_fact(model, rank=256, deepcopy=True, solver='snmf', num_iter=20)
count_param(fact_model)

CPU times: user 5min 33s, sys: 17 s, total: 5min 50s
Wall time: 23.6 s


66818304

### Apply percentage rank

In [8]:
%%time
fact_model = auto_fact(model, rank=0.4, deepcopy=True, solver='random', num_iter=20)
count_param(fact_model)

CPU times: user 1.51 s, sys: 521 ms, total: 2.03 s
Wall time: 428 ms


58052352

In [9]:
%%time
fact_model = auto_fact(model, rank=0.4, deepcopy=True, solver='svd', num_iter=20)
count_param(fact_model)

CPU times: user 4min 33s, sys: 3.34 s, total: 4min 36s
Wall time: 17.7 s


58052352

In [10]:
%%time
fact_model = auto_fact(model, rank=0.4, deepcopy=True, solver='snmf', num_iter=20)
count_param(fact_model)

CPU times: user 5min 9s, sys: 14.1 s, total: 5min 23s
Wall time: 21.5 s


58052352

### Apply factorization only on specific modules

In [11]:
# Only factorize the last N_FACTORIZED_LAYERS transformer layers and the pooler
# layer of the model. The start index is computed from the encoder depth so that
# "last N" stays correct for any number of layers (for a 12-layer encoder this
# is the same selection as layer[6:]). max(..., 0) guards N larger than the
# depth, and N = 0 selects no encoder layers. (A plain layer[-N:] would
# wrongly select *all* layers when N = 0.)
N_FACTORIZED_LAYERS = 6
encoder_layers = model.encoder.layer
first_factorized = max(len(encoder_layers) - N_FACTORIZED_LAYERS, 0)
factorizable_submodules = list(encoder_layers[first_factorized:]) + [model.pooler]

In [12]:
%%time
fact_model = auto_fact(model, rank=0.2, deepcopy=True, solver='random', num_iter=20, submodules=factorizable_submodules)
count_param(fact_model)

CPU times: user 1.01 s, sys: 388 ms, total: 1.39 s
Wall time: 197 ms


74965248

In [13]:
%%time
fact_model = auto_fact(model, rank=0.2, deepcopy=True, solver='svd', num_iter=20, submodules=factorizable_submodules)
count_param(fact_model)

CPU times: user 1min 16s, sys: 1.04 s, total: 1min 18s
Wall time: 5 s


74965248

In [14]:
%%time
fact_model = auto_fact(model, rank=0.2, deepcopy=True, solver='snmf', num_iter=20, submodules=factorizable_submodules)
count_param(fact_model)

CPU times: user 1min 55s, sys: 2.34 s, total: 1min 57s
Wall time: 7.72 s


74965248

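# Speed test on CPU

Note: when the notebook is run top to bottom, `fact_model` in the speed tests is the model produced by the last factorization cell above (`solver='snmf'`, `rank=0.2`, only the last 6 encoder layers and the pooler factorized). It is not one of the fully factorized variants.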

### Test Inference CPU

In [15]:
%%timeit
with torch.no_grad():
    y = model(torch.zeros(32,256, dtype=torch.long))

4.15 s ± 31.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [16]:
%%timeit
with torch.no_grad():
    y = fact_model(torch.zeros(32,256, dtype=torch.long))

3.01 s ± 37.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Test Forward-Backward CPU

In [17]:
%%timeit
y = model(torch.zeros(8,256, dtype=torch.long))
y.last_hidden_state.sum().backward()

3.33 s ± 158 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [18]:
%%timeit
y = fact_model(torch.zeros(8,256, dtype=torch.long))
y.last_hidden_state.sum().backward()

2.41 s ± 144 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Speed test on GPU

### Move models to GPU

In [19]:
# Drop gradients left over from the CPU forward-backward benchmarks before
# moving to the GPU: Module.cuda() also moves every populated .grad tensor,
# so stale gradients would otherwise occupy GPU memory during the GPU tests.
model.zero_grad(set_to_none=True)
fact_model.zero_grad(set_to_none=True)
model = model.cuda()
fact_model = fact_model.cuda()

### Test Inference GPU

In [20]:
# GPU inference input: 64 sequences of 256 token ids (all zeros). Allocated
# directly on the current CUDA device instead of building it on the host and
# copying it over with .cuda().
x = torch.zeros(64, 256, dtype=torch.long, device='cuda')

In [21]:
%%timeit
with torch.no_grad():
    y = model(x)

785 ms ± 2.1 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [22]:
%%timeit
with torch.no_grad():
    y = fact_model(x)

683 ms ± 123 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Test Forward-Backward GPU

In [23]:
# GPU forward-backward input: 16 sequences of 256 token ids (all zeros),
# allocated directly on the current CUDA device (rebinds `x`).
x = torch.zeros(16, 256, dtype=torch.long, device='cuda')

In [24]:
%%timeit
y = model(x)
y.last_hidden_state.sum().backward()

672 ms ± 15.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [25]:
%%timeit
y = fact_model(x)
y.last_hidden_state.sum().backward()

518 ms ± 14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
