In [97]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [98]:
import os, sys
import torch
from transformers import BertModel, BertConfig
# from py_auto_fact import auto_fact
from itertools import chain

from os import path
import sys
sys.path.append(path.abspath('../../py_auto_fact'))
from src.py_auto_fact.auto_fact import auto_fact

In [99]:
def count_param(module, trainable=False):
    """Return the total number of scalar parameters held by ``module``.

    Args:
        module: Object exposing ``parameters()`` whose items provide
            ``numel()`` and ``requires_grad``.
        trainable: When truthy, only parameters with ``requires_grad`` set
            are counted; otherwise every parameter is counted.

    Returns:
        The summed element count as an ``int`` (``0`` if there are no
        parameters).
    """
    total = 0
    for param in module.parameters():
        # Frozen parameters are skipped only when the caller asked for the
        # trainable subset.
        if trainable and not param.requires_grad:
            continue
        total += param.numel()
    return total

# Init Model

In [100]:
# Load the pretrained BERT-base weights. `from_pretrained` builds the model
# from the checkpoint's own configuration, so the randomly initialised
# `BertModel(config=config)` that used to be created first was discarded
# immediately and only cost time and memory.
# `config` is kept so that any cell relying on it keeps working.
config = BertConfig.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [101]:
model

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

# Factorize Model

In [102]:
count_param(model)

109482240

### Apply absolute rank

In [103]:
%%time
# Factorize `model` with an absolute rank of 256 using the 'random' solver,
# then report the parameter count of the result so it can be compared with
# the count of the original model shown above.
fact_model = auto_fact(model, rank=256, deepcopy=True, solver='random')
count_param(fact_model)

CPU times: user 502 ms, sys: 59.9 ms, total: 562 ms
Wall time: 292 ms


66818304

In [104]:
fact_model

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): LED(
              (led_unit): Sequential(
                (0): Linear(in_features=768, out_features=256, bias=False)
                (1): Linear(in_features=256, out_features=768, bias=True)
              )
            )
            (key): LED(
              (led_unit): Sequential(
                (0): Linear(in_features=768, out_features=256, bias=False)
                (1): Linear(in_features=256, out_features=768, bias=True)
              )
            )
            (value): LED(
              (led_unit): S

In [105]:
# Run auto_fact a second time on the already-factorized model. The result is
# stored in `fact_fact_model`, which no later visible cell reads.
# NOTE(review): presumably a smoke test that re-factorizing an already
# factorized model does not fail — confirm the intent.
fact_fact_model = auto_fact(fact_model, rank=256, deepcopy=True, solver='random')

In [106]:
# Restrict factorization to a hand-picked subset of sub-modules: encoder
# layers 6, 8, 9 and 11 plus the pooler.
factorizable_module_list = [
    model.encoder.layer[6],
    model.encoder.layer[8],
    model.encoder.layer[9],
    model.encoder.layer[11],
    model.pooler]

# rank=0.5 is a fractional rank (presumably a ratio of the full rank rather
# than an absolute size — verify against the auto_fact documentation).
# NOTE(review): the listed modules are attributes of `model`, while
# deepcopy=True presumably makes auto_fact operate on a copy — confirm that
# auto_fact maps these modules onto the copied model.
fact_model2 = auto_fact(model, rank=0.5, deepcopy=True, solver='random', num_iter=100, factorizable_module_list=factorizable_module_list)

# Display the resulting module tree to inspect which layers were replaced.
fact_model2

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

In [16]:
%%time
fact_model = auto_fact(model, rank=256, deepcopy=True, solver='svd', eigen_threshold=0.6)
count_param(fact_model)

CPU times: user 2min 6s, sys: 1.07 s, total: 2min 7s
Wall time: 23.2 s


69570816

In [8]:
%%time
fact_model = auto_fact(model, rank=256, deepcopy=True, solver='snmf', eigen_threshold=0.6)
count_param(fact_model)

CPU times: user 1min 32s, sys: 2.13 s, total: 1min 34s
Wall time: 11 s


69570816

In [9]:
%%time
fact_model = auto_fact(model, rank=256, deepcopy=True, solver='random', num_iter=50, eigen_threshold=0.0)
count_param(fact_model)

CPU times: user 53 s, sys: 380 ms, total: 53.4 s
Wall time: 5.35 s


66818304

In [10]:
%%time
fact_model = auto_fact(model, rank=256, deepcopy=True, solver='svd', num_iter=50, eigen_threshold=0.0)
count_param(fact_model)

CPU times: user 6min 16s, sys: 8.02 s, total: 6min 24s
Wall time: 40.5 s


66818304

In [11]:
%%time
fact_model = auto_fact(model, rank=256, deepcopy=True, solver='snmf', num_iter=50, eigen_threshold=0.0)
count_param(fact_model)

CPU times: user 3min 50s, sys: 2.94 s, total: 3min 53s
Wall time: 28.7 s


66818304

### Apply percentage rank

In [12]:
%%time
fact_model = auto_fact(model, rank=0.4, deepcopy=True, solver='random', num_iter=50)
count_param(fact_model)

CPU times: user 672 ms, sys: 8 ms, total: 680 ms
Wall time: 248 ms


75356928

In [13]:
%%time
fact_model = auto_fact(model, rank=0.4, deepcopy=True, solver='svd', num_iter=50)
count_param(fact_model)

CPU times: user 6min 43s, sys: 6.01 s, total: 6min 49s
Wall time: 41.1 s


75356928

In [14]:
%%time
fact_model = auto_fact(model, rank=0.4, deepcopy=True, solver='snmf', num_iter=50)
count_param(fact_model)

CPU times: user 3min 36s, sys: 4.05 s, total: 3min 40s
Wall time: 28.4 s


75356928

In [15]:
%%time
fact_model = auto_fact(model, rank=0.2, deepcopy=True, solver='random', num_iter=50)
count_param(fact_model)

CPU times: user 528 ms, sys: 20 ms, total: 548 ms
Wall time: 164 ms


49573632

In [16]:
%%time
fact_model = auto_fact(model, rank=0.2, deepcopy=True, solver='svd', num_iter=50)
count_param(fact_model)

CPU times: user 2min 28s, sys: 1.29 s, total: 2min 29s
Wall time: 15 s


49573632

In [17]:
%%time
fact_model = auto_fact(model, rank=0.2, deepcopy=True, solver='snmf', num_iter=50)
count_param(fact_model)

CPU times: user 2min 24s, sys: 2.59 s, total: 2min 27s
Wall time: 15.7 s


49573632

# Test on CPU

### Test Inference CPU

In [18]:
%%timeit
# Time a forward pass of the original model on a dummy all-zero batch of
# shape (32, 128) with autograd disabled. The input tensor is created inside
# the timed statement, so its allocation is part of the measurement.
with torch.no_grad():
    y = model(torch.zeros(32,128, dtype=torch.long))

690 ms ± 2.76 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
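# CPU inference timings recorded by the two %%timeit cells in this section
# (dummy batch of shape (32, 128)):
#   original model (`model`):        690 ms ± 2.76 ms
#   factorized model (`fact_model`): 452 ms ± 0.88 ms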

In [19]:
%%timeit
# Time a forward pass of the factorized model on a dummy all-zero batch of
# shape (32, 128) with autograd disabled. The input tensor is created inside
# the timed statement, so its allocation is part of the measurement.
with torch.no_grad():
    y = fact_model(torch.zeros(32,128, dtype=torch.long))

452 ms ± 882 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Test Forward-Backward CPU

In [20]:
%%timeit
# Time one forward + backward pass of the original model on a dummy all-zero
# batch of shape (8, 128); the summed last hidden state serves as a scalar
# to differentiate.
# NOTE(review): gradients are never zeroed between timing loops, so they
# accumulate in the parameters' .grad — clear them before any real training.
y = model(torch.zeros(8,128, dtype=torch.long))
y.last_hidden_state.sum().backward()

613 ms ± 13.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [21]:
%%timeit
# Time one forward + backward pass of the factorized model on a dummy
# all-zero batch of shape (8, 128); the summed last hidden state serves as a
# scalar to differentiate.
# NOTE(review): gradients are never zeroed between timing loops, so they
# accumulate in the parameters' .grad — clear them before any real training.
y = fact_model(torch.zeros(8,128, dtype=torch.long))
y.last_hidden_state.sum().backward()

395 ms ± 1.52 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Test on GPU

### Move models to GPU

In [22]:
model = model.cuda()
fact_model = fact_model.cuda()

### Test Inference GPU

In [23]:
x = torch.zeros(64,128, dtype=torch.long).cuda()

In [24]:
%%timeit
with torch.no_grad():
    y = model(x)

210 ms ± 555 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [25]:
%%timeit
with torch.no_grad():
    y = fact_model(x)

159 ms ± 44.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Test Forward-Backward GPU

In [26]:
x = torch.zeros(16,128, dtype=torch.long).cuda()

In [27]:
%%timeit
y = model(x)
y.last_hidden_state.sum().backward()

158 ms ± 331 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [28]:
%%timeit
y = fact_model(x)
y.last_hidden_state.sum().backward()

94.6 ms ± 85.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
