In [1]:
import torch
from transformers import BertModel, BertConfig
from py_auto_fact import auto_fact

# Init Model

In [2]:
# Only the architecture config comes from 'bert-base-uncased'; the model is
# constructed from that config alone (BertModel.from_pretrained is not used).
config = BertConfig.from_pretrained('bert-base-uncased')
model = BertModel(config)
model  # echo the module tree as the cell output

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

# Factorize Model

In [3]:
%%time
fact_model = auto_fact(model, rank=128, deepcopy=True, solver='random')
fact_model

CPU times: user 1.11 s, sys: 849 ms, total: 1.96 s
Wall time: 292 ms


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): LED(
              (led_unit): Sequential(
                (0): Linear(in_features=768, out_features=128, bias=False)
                (1): Linear(in_features=128, out_features=768, bias=True)
              )
            )
            (key): LED(
              (led_unit): Sequential(
                (0): Linear(in_features=768, out_features=128, bias=False)
                (1): Linear(in_features=128, out_features=768, bias=True)
              )
            )
            (value): LED(
              (led_unit): S

In [4]:
%%time
fact_model = auto_fact(model, rank=128, deepcopy=True, solver='svd')
fact_model

CPU times: user 1min 39s, sys: 7.01 s, total: 1min 46s
Wall time: 6.99 s


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): LED(
              (led_unit): Sequential(
                (0): Linear(in_features=768, out_features=128, bias=False)
                (1): Linear(in_features=128, out_features=768, bias=True)
              )
            )
            (key): LED(
              (led_unit): Sequential(
                (0): Linear(in_features=768, out_features=128, bias=False)
                (1): Linear(in_features=128, out_features=768, bias=True)
              )
            )
            (value): LED(
              (led_unit): S

In [5]:
%%time
fact_model = auto_fact(model, rank=128, deepcopy=True, solver='snmf')
fact_model

CPU times: user 46min 44s, sys: 25min 25s, total: 1h 12min 10s
Wall time: 2min 31s


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): LED(
              (led_unit): Sequential(
                (0): Linear(in_features=768, out_features=128, bias=False)
                (1): Linear(in_features=128, out_features=768, bias=True)
              )
            )
            (key): LED(
              (led_unit): Sequential(
                (0): Linear(in_features=768, out_features=128, bias=False)
                (1): Linear(in_features=128, out_features=768, bias=True)
              )
            )
            (value): LED(
              (led_unit): S

# Test on CPU

### Test Inference CPU

In [4]:
%%timeit
with torch.no_grad():
    y = model(torch.zeros(32,128, dtype=torch.long))

1.76 s ± 233 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [5]:
%%timeit
with torch.no_grad():
    y = fact_model(torch.zeros(32,128, dtype=torch.long))

854 ms ± 33.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Test Forward-Backward CPU

In [6]:
%%timeit
y = model(torch.zeros(8,128, dtype=torch.long))
y.last_hidden_state.sum().backward()

1.32 s ± 203 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
%%timeit
y = fact_model(torch.zeros(8,128, dtype=torch.long))
y.last_hidden_state.sum().backward()

599 ms ± 55.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Test on GPU

### Move models to GPU

In [8]:
# Transfer both the dense baseline and the factorized model to the GPU,
# rebinding the same names so the GPU benchmark cells below pick them up.
model, fact_model = model.cuda(), fact_model.cuda()

### Test Inference GPU

In [9]:
# Token-id batch for the GPU inference benchmark: 32 sequences of length 128,
# all zeros, created on the host and then copied to the GPU.
x = torch.zeros((32, 128), dtype=torch.long).cuda()

In [10]:
%%timeit
with torch.no_grad():
    y = model(x)

113 ms ± 251 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [11]:
%%timeit
with torch.no_grad():
    y = fact_model(x)

61.3 ms ± 32.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


### Test Forward-Backward GPU

In [12]:
# Smaller token-id batch for the GPU forward-backward benchmark: 8 sequences
# of length 128, all zeros, created on the host and then copied to the GPU.
x = torch.zeros((8, 128), dtype=torch.long).cuda()

In [13]:
%%timeit
y = model(x)
y.last_hidden_state.sum().backward()

92.3 ms ± 252 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [14]:
%%timeit
y = fact_model(x)
y.last_hidden_state.sum().backward()

58.1 ms ± 200 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
