In [1]:
import os 

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.functional as F
from tqdm import tqdm
from pytorch_pretrained_bert import BertTokenizer, BertModel

from metal.mmtl.dataset import BERTDataset
from metal.end_model import EndModel

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


### Config

In [2]:
# --- Training schedule -------------------------------------------------
epochs = 3            # number of epochs given to train_model (n_epochs)
lr = 5e-5             # learning rate given to train_model
weight_decay = 1e-2   # forwarded to train_model as its `l2` argument

# --- Input pipeline ----------------------------------------------------
batch_size = 16       # batch size requested from every dataloader
max_len = 512         # forwarded to BERTDataset as `max_len`

### Preprocess data

In [3]:
# Pretrained checkpoint name; also try bert-base-multilingual-cased (recommended).
# NOTE(review): `model` is not passed to BERTDataset in this cell -- confirm
# that the tokenizer/encoder really use the checkpoint named here.
model = 'bert-base-uncased'

# Template for the QNLI TSV files; `{}` is filled with the split name.
src_path = os.path.join(os.environ['GLUEDATA'], 'QNLI/{}.tsv')


def _qnli_label(label):
    """Map a raw label string to a class id: 'entailment' -> 1, anything else -> 2."""
    if label == 'entailment':
        return 1
    return 2


# One dataloader per split, built in the order test -> dev -> train.
dataloaders = {}
for split in ('test', 'dev', 'train'):
    # train/dev take the label from column 3; the test split is given -1.
    label_idx = -1 if split not in ('train', 'dev') else 3
    dataset = BERTDataset(
        src_path.format(split),
        sent1_idx=1,
        sent2_idx=2,
        label_idx=label_idx,
        skip_rows=1,  # presumably skips the TSV header row -- TODO confirm
        label_fn=_qnli_label,
        max_len=max_len,
    )
    dataloaders[split] = dataset.get_dataloader(batch_size=batch_size)

100%|██████████| 5463/5463 [00:06<00:00, 803.99it/s]
100%|██████████| 5463/5463 [00:06<00:00, 794.02it/s]
100%|██████████| 104743/104743 [02:09<00:00, 806.11it/s]


### Metal Model 

In [4]:
class BertEncoder(nn.Module):
    """Input module that runs a batch through a pretrained BERT model.

    Args:
        model_name: Identifier handed to ``BertModel.from_pretrained``.
            Defaults to ``'bert-base-uncased'`` (the value that used to be
            hard-coded), so ``BertEncoder()`` behaves exactly as before.
            NOTE(review): a checkpoint with a different hidden size needs a
            matching first layer dimension in the downstream head.
    """

    def __init__(self, model_name='bert-base-uncased'):
        super(BertEncoder, self).__init__()
        self.bert_model = BertModel.from_pretrained(model_name)
        # To freeze BERT and train only the head, disable gradients here:
        #for param in self.bert_model.parameters():
        #    param.requires_grad = False

    def forward(self, data):
        """Encode one batch.

        Args:
            data: A 3-tuple ``(tokens, segments, masks)`` that is passed to
                the BERT model as ``input_ids``, ``token_type_ids`` and
                ``attention_mask`` respectively.

        Returns:
            The second element of the BERT model's output; the first element
            is discarded.
        """
        tokens, segments, masks = data
        # TODO: check if we should return all layers or just last hidden representation
        # NOTE(review): in pytorch_pretrained_bert the second output is the
        # pooled first-token ([CLS]) representation, not a per-token hidden
        # layer -- confirm against the installed version.
        _, pooled_output = self.bert_model(
            input_ids=tokens,
            token_type_ids=segments,
            attention_mask=masks,
        )
        return pooled_output

In [None]:
%%time 
encoder_module = BertEncoder()
end_model = EndModel(
    [768, 2],  # TODO: remove bias
    input_module=encoder_module,
    seed=123,
    skip_head=False,
    input_relu=False,
    input_batchnorm=False,
    verbose=False,
    device=torch.device('cuda'),
)

### Train model

In [9]:
# Fine-tune the end model on the train split and validate on the dev split,
# logging once per epoch. `weight_decay` is applied through the `l2` argument.
# TODO: if batch size is 1 then assertion error with metal
end_model.train_model(
    train_data=dataloaders['train'],
    valid_data=dataloaders['dev'],
    l2=weight_decay,
    lr=lr,
    n_epochs=epochs,
    verbose=True,
    checkpoint=False,
    log_unit='epochs',
    log_train_every=1,
    log_valid_every=1,
    progress_bar=True,
)
# Extra logging/checkpoint kwargs that are currently disabled. They were
# previously kept in a bare string literal, which the notebook echoed as the
# cell's output; real comments avoid that.
#     log_train_metrics=["train/loss"],
#     log_valid_metrics=["valid/accuracy"],
#     checkpoint_metric_mode='max',
#     checkpoint_metric="valid/accuracy",

Using GPU...








[1 epo]: TRAIN:[loss=0.603] VALID:[accuracy=0.686]

Finished Training
Accuracy: 0.694
        y=1    y=2   
 l=1   2099    603   
 l=2   1066   1695   


'\n    log_train_metrics=["train/loss"],\n    log_valid_metrics=["valid/accuracy"],\n    checkpoint_metric_mode=\'max\',\n    checkpoint_metric="valid/accuracy",\n)\n'

### Evaluate model (dev set)

In [25]:
# Score the trained end model on the dev split; the returned list of metric
# values is the cell's displayed result.
eval_metrics = ["accuracy", "precision", "recall", "f1"]
end_model.score(dataloaders['dev'], metric=eval_metrics)

Accuracy: 0.693
Precision: 0.652
Recall: 0.816
F1: 0.725
        y=1    y=2   
 l=1   2205    497   
 l=2   1179   1582   


[0.6932088596009519, 0.651595744680851, 0.8160621761658031, 0.7246138678935262]

In [26]:
data = []
for i, x in enumerate(dataloaders['dev']):
    if i>1: 
        break
    (tokens, segments, masks), labels = x
    tokens = tokens.cuda()
    segments = segments.cuda()
    masks = masks.cuda()
    preds = end_model.predict_proba((tokens, segments, masks))