In [1]:
import os 

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.functional as F
from tqdm import tqdm
from pytorch_pretrained_bert import BertTokenizer, BertModel

from metal.mmtl.dataset import BERTDataset
from metal.end_model import EndModel

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


### Config

In [7]:
batch_size = 32
max_len = 250
weight_decay = 0.01
epochs = 5
lr = 0.01

### Preprocess data

In [16]:
model = 'bert-base-uncased' # also try bert-base-multilingual-cased (recommended)
src_path = os.path.join(os.environ['GLUEDATA'], 'QNLI/{}.tsv')
dataloaders = {}
for split in ['test', 'dev']: #, 'train', 'test']:
    label_idx = 3 if split in ['train', 'dev'] else -1
    dataset = BERTDataset(
        src_path.format(split),
        sent1_idx=1,
        sent2_idx=2,
        label_idx=label_idx,
        skip_rows=1,
        label_fn=lambda label: 1 if label=='entailment' else 2 
    )
    dataloaders[split] = dataset.get_dataloader(batch_size)

100%|██████████| 5463/5463 [00:09<00:00, 586.67it/s]
100%|██████████| 5463/5463 [00:07<00:00, 691.19it/s]


### Metal Model 

In [9]:
class BertEncoder(nn.Module):
    def __init__(self):
        super(BertEncoder, self).__init__()
        self.bert_model = BertModel.from_pretrained('bert-base-uncased')
        
    def forward(self, data):
        tokens, segments, masks = data
        # TODO: check if we should return all layers or just last hidden representation 
        _, hidden_layer = self.bert_model(tokens, segments, masks)
        return hidden_layer

In [10]:
%%time 
encoder_module = BertEncoder()
end_model = EndModel(
    [768, 2],
    input_module=encoder_module,
    seed=123,
    skip_head=False,
    input_relu=False,
    input_batchnorm=False,
    verbose=False,
    # device=''
)

CPU times: user 19.5 s, sys: 4.26 s, total: 23.8 s
Wall time: 18.1 s


### Train model

In [13]:
# TODO: if batch size is 1 then assertion error with metal
end_model.train_model(
    train_data=dataloaders['dev'],
    valid_data=dataloaders['dev'],
    l2=weight_decay,
    lr=lr,
    n_epochs=epochs,
    validation_metric="accuracy",
    verbose=True,
    log_train_every=1,
    log_unit='examples', 
    progress_bar=True
)

Exception ignored in: <generator object tqdm_notebook.__iter__ at 0x7f0bea638570>
Traceback (most recent call last):
  File "/dfs/scratch0/chami/miniconda3/envs/metal/lib/python3.6/site-packages/tqdm/_tqdm_notebook.py", line 226, in __iter__
    self.sp(bar_style='danger')
AttributeError: 'tqdm_notebook' object has no attribute 'sp'


KeyboardInterrupt: 

### Test model

In [None]:
# Test end model
end_model.score(dataset['dev'].get_dataloader(max_len, batch_size), metric=["accuracy", "precision", "recall", "f1"])