In [2]:
# Download model
!wget https://storage.googleapis.com/samcah-bucket/deep_fact/best_model.th ./best_model.th

--2021-09-14 09:38:40--  https://storage.googleapis.com/samcah-bucket/deep_fact/best_model.th
Resolving storage.googleapis.com (storage.googleapis.com)... 172.217.31.240, 142.250.66.144, 142.250.204.48, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|172.217.31.240|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 433331373 (413M) [application/octet-stream]
Saving to: ‘best_model.th’


2021-09-14 09:38:49 (49.1 MB/s) - ‘best_model.th’ saved [433331373/433331373]

--2021-09-14 09:38:49--  http://./best_model.th
Resolving . (.)... failed: No address associated with hostname.
wget: unable to resolve host address ‘.’
FINISHED --2021-09-14 09:38:49--
Total wall clock time: 9.1s
Downloaded: 1 files, 413M in 8.4s (49.1 MB/s)


In [3]:
import os, sys
import torch
from itertools import chain
from torch.utils.data import Dataset, DataLoader
from transformers import BertConfig, BertForSequenceClassification, BertTokenizer
from transformers.modeling_outputs import SequenceClassifierOutput
from sklearn.metrics import accuracy_score, f1_score
from greenformer import auto_fact
import datasets
from tqdm import tqdm

In [4]:
def count_param(module, trainable=False):
    """Return the total number of scalar parameters in ``module``.

    Args:
        module: any object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        trainable: when true, count only parameters whose ``requires_grad`` is set.
    """
    params = module.parameters()
    if trainable:
        params = filter(lambda p: p.requires_grad, params)
    return sum(p.numel() for p in params)

def metrics_to_string(metric_dict):
    """Render a metrics dict as space-separated ``key:value`` pairs.

    Values are formatted with 4 decimal places, in the dict's iteration order.
    An empty dict yields an empty string.
    """
    return ' '.join(f'{key}:{value:.4f}' for key, value in metric_dict.items())

# Init Model

In [5]:
config = BertConfig.from_pretrained('bert-base-cased')
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
# Building the model from a config (instead of `from_pretrained`) leaves it in training
# mode, so dropout would be active in the "inference" benchmarks below. Switch to eval
# mode before loading the fine-tuned weights. The forward-backward timings then also run
# without dropout.
# A deep copy of this model (as made by the factorization cell) presumably inherits eval
# mode. TODO verify for any new modules auto_fact creates.
model = BertForSequenceClassification(config=config).eval()
# NOTE: torch.load unpickles the checkpoint, which can execute arbitrary code; only load
# files from a trusted source. load_state_dict stays the last expression so the cell
# still displays the key-matching report.
model.load_state_dict(torch.load('./best_model.th', map_location='cpu'))

<All keys matched successfully>

In [6]:
# Total number of parameters (trainable or not) of the unfactorized model.
count_param(model)

108311810

# Apply partial factorization to BERT model

In [7]:
# Restrict factorization to the transformer layers from index 8 onward (the last four
# layers of the BERT encoder) plus the pooler; all other modules keep their dense weights.
factorizable_submodules = [*model.bert.encoder.layer[8:], model.bert.pooler]

In [8]:
%%time
# Factorize only `factorizable_submodules` with rank-192 factors using the 'svd' solver,
# then report the resulting parameter count.
# `deepcopy=True` presumably factorizes a copy so `model` stays intact for comparison
# (TODO confirm against greenformer docs). Whether `num_iter` affects the 'svd' solver
# is not visible here.
fact_model = auto_fact(model, rank=192, deepcopy=True, solver='svd', num_iter=100, submodules=factorizable_submodules)
print(count_param(fact_model))

90322178
CPU times: user 2min 12s, sys: 2.99 s, total: 2min 15s
Wall time: 14.9 s


# Speed test on CPU

### Test Inference CPU

In [9]:
%%timeit
# CPU inference latency of the original model on a dummy batch of 16 x 128 token ids.
# The input tensor is allocated inside the timed region, so its cost is included.
# Note: torch.no_grad() only disables autograd; it does not switch the model to eval mode.
with torch.no_grad():
    y = model(torch.zeros(16,128, dtype=torch.long))

383 ms ± 40 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [10]:
%%timeit
# Same CPU inference benchmark as for `model`, on the partially factorized model.
# The input tensor is allocated inside the timed region, so its cost is included.
with torch.no_grad():
    y = fact_model(torch.zeros(16,128, dtype=torch.long))

333 ms ± 1.01 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Test Forward-Backward CPU

In [11]:
%%timeit
# CPU forward + backward of the original model on a dummy 16 x 128 batch.
# Gradients are never zeroed here, so they accumulate in `.grad` across timeit loops.
y = model(torch.zeros(16,128, dtype=torch.long))
y.logits.sum().backward()

1.2 s ± 3.52 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
%%timeit
# Same CPU forward + backward benchmark, on the partially factorized model.
# Gradients are never zeroed here, so they accumulate in `.grad` across timeit loops.
y = fact_model(torch.zeros(16,128, dtype=torch.long))
y.logits.sum().backward()

1.08 s ± 4.51 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Speed test on GPU

### Move models to GPU

In [13]:
# Move both models to the default CUDA device for the GPU benchmarks and evaluation.
model, fact_model = model.cuda(), fact_model.cuda()

### Test Inference GPU

In [14]:
# Dummy GPU input: 16 sequences of 256 token ids (all zeros). The CPU benchmarks above
# use sequence length 128, so CPU and GPU timings are not directly comparable.
x = torch.zeros(16,256, dtype=torch.long).cuda()

In [15]:
%%timeit
with torch.no_grad():
    y = model(x)

126 ms ± 221 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [16]:
%%timeit
with torch.no_grad():
    y = fact_model(x)

115 ms ± 145 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


### Test Forward-Backward GPU

In [17]:
# Same dummy GPU input as in the inference section (16 x 256 token ids, all zeros),
# re-created so this section can be run on its own.
x = torch.zeros(16,256, dtype=torch.long).cuda()

In [18]:
%%timeit
y = model(x)
y.logits.sum().backward()

345 ms ± 395 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [19]:
%%timeit
y = fact_model(x)
y.logits.sum().backward()

310 ms ± 657 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Prepare Dataset and DataLoader

In [20]:
# IMDB Dataset
class IMDBDataset(Dataset):
    """IMDB sentiment dataset yielding ``(text, label)`` pairs.

    The HuggingFace ``imdb`` train split is divided positionally into a ``'train'``
    view (rows 0-22499) and a ``'validation'`` view (rows 22500-24999). ``'test'``
    is the full ``imdb`` test split.

    NOTE(review): the train/validation split is positional, not shuffled. If the
    underlying ``imdb`` train split is ordered by label (verify), the validation view
    would be (nearly) single-class. It is kept positional so it stays consistent with
    any checkpoint trained on exactly these rows.
    """

    # Static constant variable: number of sentiment classes.
    NUM_LABELS = 2

    # split name -> (HF split to load, first row, number of rows; None = whole split)
    _SPLITS = {
        'train': ('train', 0, 22500),
        'validation': ('train', 22500, 2500),
        'test': ('test', 0, None),
    }

    def __init__(self, data_split, *args, **kwargs):
        """Load the requested view of IMDB.

        Args:
            data_split: one of ``'train'``, ``'validation'`` or ``'test'``.
            *args, **kwargs: ignored.

        Raises:
            ValueError: if ``data_split`` is not a known split name.
        """
        if data_split not in self._SPLITS:
            raise ValueError(f'Invalid dataset split: `{data_split}`')
        hf_split, start_idx, data_len = self._SPLITS[data_split]
        self.data_split = data_split
        self.dataset = datasets.load_dataset('imdb')[hf_split]
        self.start_idx = start_idx
        self.data_len = len(self.dataset) if data_len is None else data_len

    def __getitem__(self, index):
        """Return ``(text, label)`` for position ``index`` within this view.

        Negative indices count from the end of the view. Raises ``IndexError`` when
        ``index`` lies outside ``[-len(self), len(self))``, so a read can never spill
        into rows that belong to a neighbouring view (e.g. train -> validation).
        """
        position = index + self.data_len if index < 0 else index
        if not 0 <= position < self.data_len:
            raise IndexError(f'index {index} out of range for dataset of length {self.data_len}')
        # Fetch the row once instead of decoding it separately for text and label.
        row = self.dataset[self.start_idx + position]
        return row['text'], row['label']

    def __len__(self):
        """Number of examples in this view."""
        return self.data_len

In [21]:
class SingleSentenceDataLoader(DataLoader):
    """DataLoader that tokenizes batches of ``(text, label)`` pairs on the fly.

    Every batch is returned as ``(input_ids, attention_mask, labels)``. Texts are
    encoded with the supplied tokenizer (``padding=True, truncation=True``) and labels
    are stacked into a LongTensor. Any ``collate_fn`` given by the caller is replaced.
    """

    def __init__(self, tokenizer, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.tokenizer = tokenizer
        # Installed after DataLoader.__init__ so it overrides whatever collate
        # function the base class configured.
        self.collate_fn = self._collate_fn

    def _collate_fn(self, batch):
        """Tokenize the texts of ``batch`` and pair them with a tensor of labels."""
        texts, labels = zip(*batch)
        encoded = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        return encoded['input_ids'], encoded['attention_mask'], torch.LongTensor(labels)

In [22]:
# Full IMDB test split, evaluated in fixed order in batches of 16 with 4 loader workers.
test_dataset = IMDBDataset('test')
test_loader = SingleSentenceDataLoader(
    tokenizer=tokenizer,
    dataset=test_dataset,
    batch_size=16,
    num_workers=4,
    shuffle=False,
)

Reusing dataset imdb (/home/samuel/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a)


# Run Evaluation

In [23]:
# Forward function for sequence classification
def forward_sequence_classification(model, batch_data, device='cpu', **kwargs):
    """Run one batch through a sequence-classification model.

    Args:
        model: callable taking ``(input_ids, attention_mask=..., token_type_ids=...,
            labels=...)``. It may return an object exposing ``.logits`` (and
            optionally ``.loss``), e.g. a ``SequenceClassifierOutput``, or a tuple
            ``(loss, logits, ...)``.
        batch_data: ``(input_ids, attention_mask, labels)`` or
            ``(input_ids, attention_mask, token_type_ids, labels)``.
        device: target passed to ``Tensor.to`` before the forward pass
            (e.g. ``'cpu'``, ``'cuda'``, ``'cuda:1'``).
        **kwargs: ignored; accepted for call-site compatibility.

    Returns:
        ``(loss, list_hyp, list_label)``:
        - ``loss`` is the model's loss, or ``None`` when the model reports none.
        - ``list_hyp`` is the argmax class of each example, as a list of ints.
        - ``list_label`` is the gold labels, as a list of ints.

    Raises:
        ValueError: if ``batch_data`` has neither 3 nor 4 elements.
    """
    # Unpack batch data
    if len(batch_data) == 3:
        input_batch, mask_batch, label_batch = batch_data
        token_type_batch = None
    elif len(batch_data) == 4:
        input_batch, mask_batch, token_type_batch, label_batch = batch_data
    else:
        raise ValueError(f'Expected a batch of 3 or 4 tensors, got {len(batch_data)}')

    # Move inputs and labels to the target device (a no-op when already there).
    input_batch = input_batch.to(device)
    mask_batch = mask_batch.to(device)
    if token_type_batch is not None:
        token_type_batch = token_type_batch.to(device)
    label_batch = label_batch.to(device)

    # Forward model
    outputs = model(input_batch, attention_mask=mask_batch, token_type_ids=token_type_batch, labels=label_batch)
    if hasattr(outputs, 'logits'):
        # ModelOutput-style result (SequenceClassifierOutput or any subclass/lookalike).
        loss, logits = getattr(outputs, 'loss', None), outputs.logits
    else:
        # Tuple-style result: (loss, logits, ...), where loss may be None.
        loss, logits = outputs[0], outputs[1]

    # One argmax and one host transfer per batch, instead of a .item() sync per example.
    list_hyp = logits.argmax(dim=-1).tolist()
    list_label = label_batch.tolist()
    return loss, list_hyp, list_label

# Metric function for calculating Accuracy and F1
def acc_f1_metrics_fn(list_hyp, list_label):
    """Score predictions against gold labels.

    Returns a dict with ``"ACC"`` (accuracy) and ``"F1"`` (macro-averaged F1),
    in that order.
    """
    return {
        "ACC": accuracy_score(list_label, list_hyp),
        "F1": f1_score(list_label, list_hyp, average='macro'),
    }

In [24]:
def predict(model, data_loader, forward_fn, metrics_fn, device='cpu'):
    """Evaluate ``model`` over ``data_loader`` and print loss and metrics.

    Args:
        model: module to evaluate; it is switched to eval mode.
        data_loader: sized iterable of batches, each passed to ``forward_fn``.
        forward_fn: ``forward_fn(model, batch, device=device)`` returning
            ``(loss, batch_hyp, batch_label)``.
        metrics_fn: ``metrics_fn(list_hyp, list_label)`` returning a dict of floats.
        device: forwarded to ``forward_fn``.

    Returns:
        ``(total_loss, metrics, list_hyp, list_label, list_seq)``:
        - ``total_loss`` is the SUM of per-batch losses; the printed value is the
          per-batch mean.
        - ``list_seq`` is always empty and is kept only for interface compatibility.
    """
    model.eval()

    total_loss = 0
    num_batches = 0
    list_hyp, list_label, list_seq = [], [], []

    # Disable autograd only for this evaluation. The previous global
    # torch.set_grad_enabled(False) stayed off for every later cell.
    with torch.no_grad():
        pbar = tqdm(data_loader, leave=True, total=len(data_loader))
        for batch_data in pbar:
            loss, batch_hyp, batch_label = forward_fn(model, batch_data, device=device)

            # Accumulate loss and predictions for the final metrics.
            total_loss += loss.item()
            num_batches += 1
            list_hyp += batch_hyp
            list_label += batch_label

            pbar.set_description("TEST LOSS:{:.4f}".format(total_loss / num_batches))

    metrics = metrics_fn(list_hyp, list_label)
    # Guard against an empty loader (previously a NameError on the loop variable).
    mean_loss = total_loss / num_batches if num_batches else 0.0
    print("TEST LOSS:{:.4f} {}".format(mean_loss, metrics_to_string(metrics)))
    return total_loss, metrics, list_hyp, list_label, list_seq

In [25]:
# Original BERT: evaluate on the IMDB test split on GPU.
# The returned `loss` is the sum of per-batch losses; the printed TEST LOSS is the mean.
loss, metrics, _, _, _ = predict(model, test_loader, forward_sequence_classification, acc_f1_metrics_fn, device='cuda')

TEST LOSS:0.4368: 100%|███████████████████████████████| 1563/1563 [07:37<00:00,  3.42it/s]


TEST LOSS:0.4368 ACC:0.9286 F1:0.9286


In [26]:
# Factorized BERT: same evaluation as above, on the partially factorized model.
# The returned `loss` is the sum of per-batch losses; the printed TEST LOSS is the mean.
loss, metrics, _, _, _ = predict(fact_model, test_loader, forward_sequence_classification, acc_f1_metrics_fn, device='cuda')

TEST LOSS:0.2305: 100%|███████████████████████████████| 1563/1563 [06:54<00:00,  3.77it/s]


TEST LOSS:0.2305 ACC:0.9243 F1:0.9243
