In [135]:
%load_ext autoreload
%autoreload 2

from metal.mmtl.task import Task
from metal.mmtl.scorer import Scorer

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [139]:
#########################
# Create Ines's model 
#########################
import os 

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.functional as F
from tqdm import tqdm
from pytorch_pretrained_bert import BertTokenizer, BertModel

from metal.mmtl.dataset import BERTDataset
from metal.end_model import EndModel

# --- Hyperparameters ---------------------------------------------------------
batch_size = 8        # handed to get_dataloader below
max_len = 200         # handed to get_dataloader below
weight_decay = 0.01   # used as the `l2` argument of train_model
epochs = 1            # used as the `n_epochs` argument of train_model
lr = 0.001

# Pretrained checkpoint name; also try bert-base-multilingual-cased (recommended)
model = 'bert-base-uncased'

# QNLI files live under $GLUEDATA; '{}' is filled in with the split name.
src_path = os.path.join(os.environ['GLUEDATA'], 'QNLI/{}.tsv')


def _qnli_label_fn(label):
    """Map a raw label string to a class id: 'entailment' -> 1, anything else -> 2."""
    if label == 'entailment':
        return 1
    return 2


# One dataloader per split, keyed by split name.
dataloaders = {}
for split in ['train', 'test', 'dev']:
    # train/dev take their label from column 3; every other split gets -1
    # (presumably because the test file carries no label column - confirm
    # against BERTDataset).
    if split in ['train', 'dev']:
        label_idx = 3
    else:
        label_idx = -1
    dataset = BERTDataset(
        src_path.format(split),
        sent1_idx=1,
        sent2_idx=2,
        label_idx=label_idx,
        skip_rows=1,
        label_fn=_qnli_label_fn,
    )
    dataloaders[split] = dataset.get_dataloader(max_len=max_len, batch_size=batch_size)
    
class BertEncoder(nn.Module):
    """Wrap a pretrained BERT model so it can be used as an input module.

    Args:
        bert_model_name: Identifier handed to ``BertModel.from_pretrained``.
            Defaults to ``'bert-base-uncased'``, the checkpoint that used to be
            hard-coded, so ``BertEncoder()`` behaves exactly as before.
        freeze: When True (the default, matching the previous behaviour) every
            BERT parameter has ``requires_grad`` switched off, so only layers
            stacked on top of the encoder are trained.
    """

    def __init__(self, bert_model_name='bert-base-uncased', freeze=True):
        super(BertEncoder, self).__init__()
        self.bert_model = BertModel.from_pretrained(bert_model_name)
        if freeze:
            for param in self.bert_model.parameters():
                param.requires_grad = False

    def forward(self, data):
        """Encode one batch.

        Args:
            data: A ``(tokens, segments, masks)`` triple, forwarded to BERT as
                ``input_ids``, ``token_type_ids`` and ``attention_mask``.

        Returns:
            The second value returned by the BERT model; the first is discarded.
        """
        tokens, segments, masks = data
        # TODO: check if we should return all layers or just last hidden representation
        # NOTE(review): in pytorch_pretrained_bert the second return value is
        # presumably the pooled [CLS] output rather than a raw hidden layer - confirm.
        _, pooled_output = self.bert_model(
            input_ids=tokens,
            token_type_ids=segments,
            attention_mask=masks,
        )
        return pooled_output
    
# Build the classifier: the BERT encoder as input module feeding the layers
# described by [768, 2].
encoder_module = BertEncoder()
end_model = EndModel(
    [768, 2],  # TODO: remove bias
    input_module=encoder_module,
    seed=123,
    skip_head=False,
    input_relu=False,
    input_batchnorm=False,
    verbose=False,
    # Prefer the GPU, but fall back to the CPU so this cell also runs on
    # machines without CUDA instead of failing at construction time.
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
)

# NOTE(review): train_data and valid_data are both the dev split, so the
# accuracy reported after training is measured on the data the model was
# fitted on - confirm this is only meant as a quick smoke test.
end_model.train_model(
    train_data=dataloaders['dev'],
    valid_data=dataloaders['dev'],
    l2=weight_decay,
    lr=lr,
    n_epochs=epochs,
    verbose=True,
    checkpoint=False,
    log_unit='epochs',
    log_train_every=1,
    log_valid_every=1,
    progress_bar=True,
)
    







  0%|          | 0/104743 [00:00<?, ?it/s][A[A[A[A[A[A





  0%|          | 147/104743 [00:00<01:11, 1469.82it/s][A[A[A[A[A[A





  0%|          | 290/104743 [00:00<01:11, 1456.97it/s][A[A[A[A[A[A





  0%|          | 443/104743 [00:00<01:10, 1475.07it/s][A[A[A[A[A[A





  1%|          | 585/104743 [00:00<01:11, 1454.19it/s][A[A[A[A[A[A





  1%|          | 726/104743 [00:00<01:12, 1440.30it/s][A[A[A[A[A[A





  1%|          | 874/104743 [00:00<01:11, 1451.16it/s][A[A[A[A[A[A





  1%|          | 1011/104743 [00:00<01:12, 1424.12it/s][A[A[A[A[A[A





  1%|          | 1156/104743 [00:00<01:12, 1431.54it/s][A[A[A[A[A[A





  1%|          | 1296/104743 [00:00<01:12, 1419.11it/s][A[A[A[A[A[A





  1%|▏         | 1435/104743 [00:01<01:13, 1408.94it/s][A[A[A[A[A[A





  2%|▏         | 1583/104743 [00:01<01:12, 1428.84it/s][A[A[A[A[A[A





  2%|▏         | 1725/104743 [00:01<01:12, 1425.87it/s][A[A[A

 14%|█▍        | 14694/104743 [00:10<01:02, 1447.58it/s][A[A[A[A[A[A





 14%|█▍        | 14840/104743 [00:10<01:02, 1432.14it/s][A[A[A[A[A[A





 14%|█▍        | 14984/104743 [00:10<01:02, 1430.09it/s][A[A[A[A[A[A





 14%|█▍        | 15128/104743 [00:10<01:03, 1418.73it/s][A[A[A[A[A[A





 15%|█▍        | 15278/104743 [00:10<01:02, 1438.57it/s][A[A[A[A[A[A





 15%|█▍        | 15423/104743 [00:10<01:02, 1431.97it/s][A[A[A[A[A[A





 15%|█▍        | 15568/104743 [00:10<01:02, 1435.19it/s][A[A[A[A[A[A





 15%|█▌        | 15715/104743 [00:11<01:01, 1442.47it/s][A[A[A[A[A[A





 15%|█▌        | 15860/104743 [00:11<01:01, 1433.88it/s][A[A[A[A[A[A





 15%|█▌        | 16004/104743 [00:11<01:01, 1431.60it/s][A[A[A[A[A[A





 15%|█▌        | 16148/104743 [00:11<01:02, 1419.80it/s][A[A[A[A[A[A





 16%|█▌        | 16291/104743 [00:11<01:03, 1395.70it/s][A[A[A[A[A[A





 16%|█▌        | 16431/104743 [00:11<01:

 28%|██▊       | 29173/104743 [00:21<01:04, 1163.23it/s][A[A[A[A[A[A





 28%|██▊       | 29300/104743 [00:21<01:03, 1192.18it/s][A[A[A[A[A[A





 28%|██▊       | 29442/104743 [00:21<01:00, 1249.97it/s][A[A[A[A[A[A





 28%|██▊       | 29582/104743 [00:21<00:58, 1289.39it/s][A[A[A[A[A[A





 28%|██▊       | 29725/104743 [00:21<00:56, 1327.58it/s][A[A[A[A[A[A





 29%|██▊       | 29865/104743 [00:21<00:55, 1348.23it/s][A[A[A[A[A[A





 29%|██▊       | 30005/104743 [00:21<00:54, 1362.74it/s][A[A[A[A[A[A





 29%|██▉       | 30143/104743 [00:21<00:55, 1350.58it/s][A[A[A[A[A[A





 29%|██▉       | 30286/104743 [00:21<00:54, 1372.12it/s][A[A[A[A[A[A





 29%|██▉       | 30429/104743 [00:21<00:53, 1385.12it/s][A[A[A[A[A[A





 29%|██▉       | 30569/104743 [00:22<00:53, 1385.14it/s][A[A[A[A[A[A





 29%|██▉       | 30708/104743 [00:22<00:54, 1351.74it/s][A[A[A[A[A[A





 29%|██▉       | 30844/104743 [00:22<00:

 42%|████▏     | 43845/104743 [00:31<00:41, 1450.86it/s][A[A[A[A[A[A





 42%|████▏     | 43991/104743 [00:31<00:42, 1442.26it/s][A[A[A[A[A[A





 42%|████▏     | 44136/104743 [00:31<00:42, 1442.18it/s][A[A[A[A[A[A





 42%|████▏     | 44281/104743 [00:31<00:42, 1412.80it/s][A[A[A[A[A[A





 42%|████▏     | 44423/104743 [00:31<00:43, 1381.52it/s][A[A[A[A[A[A





 43%|████▎     | 44562/104743 [00:32<00:44, 1357.30it/s][A[A[A[A[A[A





 43%|████▎     | 44699/104743 [00:32<00:44, 1353.64it/s][A[A[A[A[A[A





 43%|████▎     | 44838/104743 [00:32<00:43, 1364.19it/s][A[A[A[A[A[A





 43%|████▎     | 44982/104743 [00:32<00:43, 1384.76it/s][A[A[A[A[A[A





 43%|████▎     | 45128/104743 [00:32<00:42, 1404.62it/s][A[A[A[A[A[A





 43%|████▎     | 45269/104743 [00:32<00:42, 1405.54it/s][A[A[A[A[A[A





 43%|████▎     | 45410/104743 [00:32<00:42, 1381.90it/s][A[A[A[A[A[A





 43%|████▎     | 45562/104743 [00:32<00:

 56%|█████▌    | 58542/104743 [00:41<00:32, 1428.44it/s][A[A[A[A[A[A





 56%|█████▌    | 58686/104743 [00:42<00:32, 1410.36it/s][A[A[A[A[A[A





 56%|█████▌    | 58828/104743 [00:42<00:32, 1399.07it/s][A[A[A[A[A[A





 56%|█████▋    | 58978/104743 [00:42<00:32, 1427.60it/s][A[A[A[A[A[A





 56%|█████▋    | 59122/104743 [00:42<00:32, 1422.93it/s][A[A[A[A[A[A





 57%|█████▋    | 59265/104743 [00:42<00:32, 1420.37it/s][A[A[A[A[A[A





 57%|█████▋    | 59412/104743 [00:42<00:31, 1434.65it/s][A[A[A[A[A[A





 57%|█████▋    | 59559/104743 [00:42<00:31, 1444.90it/s][A[A[A[A[A[A





 57%|█████▋    | 59704/104743 [00:42<00:31, 1431.33it/s][A[A[A[A[A[A





 57%|█████▋    | 59848/104743 [00:42<00:31, 1427.47it/s][A[A[A[A[A[A





 57%|█████▋    | 59993/104743 [00:42<00:31, 1430.35it/s][A[A[A[A[A[A





 57%|█████▋    | 60137/104743 [00:43<00:38, 1161.90it/s][A[A[A[A[A[A





 58%|█████▊    | 60262/104743 [00:43<00:

 70%|██████▉   | 72817/104743 [00:52<00:22, 1416.06it/s][A[A[A[A[A[A





 70%|██████▉   | 72960/104743 [00:52<00:22, 1415.02it/s][A[A[A[A[A[A





 70%|██████▉   | 73102/104743 [00:53<00:22, 1390.86it/s][A[A[A[A[A[A





 70%|██████▉   | 73246/104743 [00:53<00:22, 1403.44it/s][A[A[A[A[A[A





 70%|███████   | 73389/104743 [00:53<00:22, 1410.11it/s][A[A[A[A[A[A





 70%|███████   | 73531/104743 [00:53<00:22, 1406.20it/s][A[A[A[A[A[A





 70%|███████   | 73672/104743 [00:53<00:22, 1398.21it/s][A[A[A[A[A[A





 70%|███████   | 73813/104743 [00:53<00:22, 1401.61it/s][A[A[A[A[A[A





 71%|███████   | 73954/104743 [00:53<00:22, 1386.00it/s][A[A[A[A[A[A





 71%|███████   | 74095/104743 [00:53<00:22, 1391.47it/s][A[A[A[A[A[A





 71%|███████   | 74236/104743 [00:53<00:21, 1396.74it/s][A[A[A[A[A[A





 71%|███████   | 74376/104743 [00:53<00:21, 1386.32it/s][A[A[A[A[A[A





 71%|███████   | 74515/104743 [00:54<00:

 83%|████████▎ | 87407/104743 [01:03<00:12, 1426.51it/s][A[A[A[A[A[A





 84%|████████▎ | 87550/104743 [01:03<00:12, 1406.55it/s][A[A[A[A[A[A





 84%|████████▎ | 87691/104743 [01:03<00:12, 1382.72it/s][A[A[A[A[A[A





 84%|████████▍ | 87835/104743 [01:03<00:12, 1398.17it/s][A[A[A[A[A[A





 84%|████████▍ | 87977/104743 [01:03<00:11, 1404.38it/s][A[A[A[A[A[A





 84%|████████▍ | 88118/104743 [01:03<00:11, 1404.75it/s][A[A[A[A[A[A





 84%|████████▍ | 88259/104743 [01:03<00:11, 1398.12it/s][A[A[A[A[A[A





 84%|████████▍ | 88399/104743 [01:03<00:11, 1393.13it/s][A[A[A[A[A[A





 85%|████████▍ | 88539/104743 [01:04<00:13, 1173.11it/s][A[A[A[A[A[A





 85%|████████▍ | 88680/104743 [01:04<00:13, 1230.88it/s][A[A[A[A[A[A





 85%|████████▍ | 88821/104743 [01:04<00:12, 1278.50it/s][A[A[A[A[A[A





 85%|████████▍ | 88958/104743 [01:04<00:12, 1301.00it/s][A[A[A[A[A[A





 85%|████████▌ | 89104/104743 [01:04<00:

 97%|█████████▋| 101973/104743 [01:13<00:02, 1380.19it/s][A[A[A[A[A[A





 97%|█████████▋| 102122/104743 [01:13<00:01, 1410.68it/s][A[A[A[A[A[A





 98%|█████████▊| 102269/104743 [01:13<00:01, 1427.67it/s][A[A[A[A[A[A





 98%|█████████▊| 102413/104743 [01:14<00:01, 1417.55it/s][A[A[A[A[A[A





 98%|█████████▊| 102555/104743 [01:14<00:01, 1392.31it/s][A[A[A[A[A[A





 98%|█████████▊| 102695/104743 [01:14<00:01, 1369.24it/s][A[A[A[A[A[A





 98%|█████████▊| 102833/104743 [01:14<00:01, 1349.60it/s][A[A[A[A[A[A





 98%|█████████▊| 102969/104743 [01:14<00:01, 1334.56it/s][A[A[A[A[A[A





 98%|█████████▊| 103107/104743 [01:14<00:01, 1347.56it/s][A[A[A[A[A[A





 99%|█████████▊| 103247/104743 [01:14<00:01, 1360.78it/s][A[A[A[A[A[A





 99%|█████████▊| 103393/104743 [01:14<00:00, 1385.53it/s][A[A[A[A[A[A





 99%|█████████▉| 103542/104743 [01:14<00:00, 1414.78it/s][A[A[A[A[A[A





 99%|█████████▉| 103684/1047

Using GPU...


HBox(children=(IntProgress(value=0, max=683), HTML(value='')))

Finished Training
Accuracy: 0.685
        y=1    y=2   
 l=1   2003    699   
 l=2   1023   1738   


In [140]:
# Test 

def custom_eval_function(model, dataloader):
    """Dummy metric function used to exercise the Scorer's custom-metric hook.

    Both arguments are ignored: the function announces itself on stdout and
    returns a constant metric dictionary.
    """
    print("Running custom_eval_function")
    metrics = {"custom_metric": 0}
    return metrics

# Create a scorer that only runs the custom metric (standard_metrics are broken)
dummy_scorer = Scorer(standard_metrics=[], custom_metric_fns=[custom_eval_function])

# Create task with scorer
data_loaders = [dataloaders[x] for x in ["train", "test", "dev"]]
# Use the scorer built above. The bare name `scorer` that was referenced here
# is not defined in any visible cell, so it only resolved through leftover
# kernel state and would raise NameError on a fresh Restart & Run All.
foo_task = Task(name="foo_task",
                input_module=encoder_module,
                head_module=end_model,
                data_loaders=data_loaders, scorers=[dummy_scorer])

# Call scorer on model / task / etc (last loader in the list is the dev split)
dummy_scorer(foo_task, end_model, data_loaders[-1], split_name="test_scorer")

Running custom_eval_function


{'test_scorer/custom_metric': 0}

In [144]:
# Test a standard metric ("f1") through the Scorer

# Scorer restricted to the built-in f1 metric. Standard metrics are broken:
# the recorded run of this cell ends in a CUDA-vs-CPU backend RuntimeError.
dummy_loss_scorer = Scorer(standard_metrics=["f1"])

# Loaders in the fixed order train, test, dev
data_loaders = [dataloaders[split_name] for split_name in ["train", "test", "dev"]]

# Task wired to the scorer under test
foo_task = Task(
    name="foo_task",
    input_module=encoder_module,
    head_module=end_model,
    data_loaders=data_loaders,
    scorers=[dummy_loss_scorer],
)

# Score the last loader in the list (the dev split)
dummy_loss_scorer(foo_task, end_model, data_loaders[-1], split_name="test_scorer")

Batch 0 of 683
<class 'torch.Tensor'>


RuntimeError: Expected object of backend CUDA but got backend CPU for argument #3 'index'

In [127]:
# Test head_output optimization