In [1]:
import os 

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.functional as F
from tqdm import tqdm
from pytorch_pretrained_bert import BertTokenizer, BertModel

from metal.mmtl.dataset import QNLIDataset
from metal.mmtl.modules import BertEncoder
from metal.end_model import EndModel

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


### Config

In [6]:
batch_size = 8
max_len = 256
weight_decay = 0.01
epochs = 3
lr = 1e-5
bert_model = 'bert-base-uncased' # also try bert-base-multilingual-cased (recommended)

### Preprocess data

In [3]:
dataloaders = {}
for split in ['test', 'dev']: # 'train'
    dataset = QNLIDataset(
        split=split,
        bert_model=bert_model,
        max_len=max_len
    )
    dataloaders[split] = dataset.get_dataloader(batch_size=batch_size, shuffle=True)

100%|██████████| 5463/5463 [00:06<00:00, 831.01it/s]
100%|██████████| 5463/5463 [00:06<00:00, 832.56it/s]
100%|██████████| 104743/104743 [02:05<00:00, 837.34it/s]


### Metal Model 

In [7]:
%%time   
encoder_module = BertEncoder(bert_model)
end_model = EndModel(
    [768, 2],  # TODO: remove bias
    input_module=encoder_module,
    seed=123,
    skip_head=False,
    input_relu=False,
    input_batchnorm=False,
    verbose=False,
    device=torch.device('cuda'),
)

CPU times: user 19.5 s, sys: 4.51 s, total: 24.1 s
Wall time: 18.2 s


### Train model

In [8]:
# TODO: if batch size is 1 then assertion error with metal
end_model.train_model(
    train_data=dataloaders['train'],
    #valid_data=dataloaders['dev'],
    l2=weight_decay,
    lr=lr,
    n_epochs=1,
    verbose=True,
    checkpoint=False,
    log_unit='epochs', 
    log_train_every=1,
    log_valid_every=1,
    progress_bar=True,
)

Using GPU...


Finished Training


### Test model

In [9]:
# Test end model
end_model.score(dataloaders['dev'], metric=["accuracy", "precision", "recall", "f1"])

Accuracy: 0.758
Precision: 0.701
Recall: 0.890
F1: 0.784
        y=1    y=2   
 l=1   2404    298   
 l=2   1026   1735   


[0.7576423210690097,
 0.7008746355685131,
 0.8897113249444856,
 0.7840834964122635]