In [1]:
import os, sys
import torch
from transformers import BertModel, BertConfig
from py_auto_fact import auto_fact
from itertools import chain

from os import path
import sys

In [2]:
def count_param(module, trainable=False):
    """Return the number of scalar parameters held by ``module``.

    Args:
        module: a ``torch.nn.Module`` whose ``parameters()`` are counted.
        trainable: when truthy, only parameters with ``requires_grad`` set
            are included; otherwise every parameter is counted.

    Returns:
        int: total element count over the selected parameters (0 if none).
    """
    selected = (
        param for param in module.parameters()
        if not trainable or param.requires_grad
    )
    return sum(param.numel() for param in selected)

# Init Model

In [3]:
# Configuration of the pretrained BERT-base (uncased) checkpoint.
config = BertConfig.from_pretrained('bert-base-uncased')
# Load the pretrained weights directly. Building `BertModel(config=config)`
# first would randomly initialise the whole network only to discard it on
# the next line; passing `config` reuses the config loaded above.
model = BertModel.from_pretrained('bert-base-uncased', config=config)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
# Parameter count of the dense (unfactorized) baseline, non-trainable included.
count_param(module=model, trainable=False)

109482240

# Factorize Model

### Apply absolute rank

In [8]:
%%time
fact_model = auto_fact(model, rank=256, deepcopy=True, solver='random', num_iter=20)
count_param(fact_model)

CPU times: user 568 ms, sys: 40 ms, total: 608 ms
Wall time: 214 ms


66818304

In [9]:
%%time
fact_model = auto_fact(model, rank=256, deepcopy=True, solver='svd', num_iter=20)
count_param(fact_model)

CPU times: user 4min 59s, sys: 6.22 s, total: 5min 5s
Wall time: 33.7 s


66818304

In [10]:
%%time
fact_model = auto_fact(model, rank=256, deepcopy=True, solver='snmf', num_iter=20)
count_param(fact_model)

CPU times: user 3min 14s, sys: 10.6 s, total: 3min 25s
Wall time: 27 s


66818304

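### Apply percentage rank

Here `rank` is a fraction (`0.4`) rather than an absolute rank. The parameter counts below match giving each factorized `in × out` weight the rank $\lfloor 0.4 \cdot \frac{in \cdot out}{in + out} \rfloor$. That keeps about 40% of each layer's weight parameters, which is why the total is smaller than with `rank=256` above.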

In [11]:
%%time
fact_model = auto_fact(model, rank=0.4, deepcopy=True, solver='random', num_iter=20)
count_param(fact_model)

CPU times: user 532 ms, sys: 28 ms, total: 560 ms
Wall time: 178 ms


58052352

In [12]:
%%time
fact_model = auto_fact(model, rank=0.4, deepcopy=True, solver='svd', num_iter=20)
count_param(fact_model)

CPU times: user 3min 37s, sys: 6.28 s, total: 3min 43s
Wall time: 23.8 s


58052352

In [13]:
%%time
fact_model = auto_fact(model, rank=0.4, deepcopy=True, solver='snmf', num_iter=20)
count_param(fact_model)

CPU times: user 2min 48s, sys: 8.92 s, total: 2min 57s
Wall time: 19.5 s


58052352

# Speed test on CPU

### Test Inference CPU

In [17]:
%%timeit
with torch.no_grad():
    y = model(torch.zeros(32,128, dtype=torch.long))

692 ms ± 1.15 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [19]:
%%timeit
with torch.no_grad():
    y = fact_model(torch.zeros(32,128, dtype=torch.long))

420 ms ± 731 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Test Forward-Backward CPU

In [20]:
%%timeit
y = model(torch.zeros(8,128, dtype=torch.long))
y.last_hidden_state.sum().backward()

581 ms ± 4.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [21]:
%%timeit
y = fact_model(torch.zeros(8,128, dtype=torch.long))
y.last_hidden_state.sum().backward()

332 ms ± 815 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Speed test on GPU

### Move models to GPU

In [22]:
# Move the dense baseline and the factorized model to the current CUDA device
# (Module.cuda() moves parameters in place and returns the module itself).
model, fact_model = model.cuda(), fact_model.cuda()

### Test Inference GPU

In [23]:
# Inference batch: 64 sequences x 128 token ids, allocated directly on the
# current CUDA device (no host allocation + host-to-device copy).
x = torch.zeros(64, 128, dtype=torch.long, device='cuda')

In [24]:
%%timeit
with torch.no_grad():
    y = model(x)

212 ms ± 628 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [25]:
%%timeit
with torch.no_grad():
    y = fact_model(x)

136 ms ± 54.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Test Forward-Backward GPU

In [26]:
# Forward-backward batch: 16 sequences x 128 token ids, allocated directly on
# the current CUDA device (rebinds `x` used by the inference cells above).
x = torch.zeros(16, 128, dtype=torch.long, device='cuda')

In [27]:
%%timeit
y = model(x)
y.last_hidden_state.sum().backward()

162 ms ± 1.66 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [28]:
%%timeit
y = fact_model(x)
y.last_hidden_state.sum().backward()

80.6 ms ± 571 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
