In [1]:
# Load IPython's autoreload extension and switch it to mode 2, so edits to
# imported modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
import os

import numpy as np
import torch.nn as nn

from pytorch_pretrained_bert import BertTokenizer
from metal.mmtl.metal_model import MetalModel
from metal.mmtl.glue.glue_tasks import create_tasks
from metal.mmtl.trainer import MultitaskTrainer

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


### Config

In [3]:
# --- Model -----------------------------------------------------------------
# Name of the pretrained BERT checkpoint; also used to load the tokenizer.
bert_model = "bert-base-uncased"
# Passed to create_tasks as `bert_output_dim`.
bert_model_output_shape = 768

# --- Data ------------------------------------------------------------------
# Passed to create_tasks as `max_len`.
max_len = 200
# Forwarded to the dataloaders through `dl_kwargs`.
batch_size = 16
# NOTE(review): defined here, but the create_tasks call in this notebook passes
# `split_prop=None` instead of this value -- confirm which one is intended.
split_prop = 0.8
# Passed to create_tasks as `max_datapoints`; kept tiny for a quick demo run.
max_datapoints = 10

Create task object

In [4]:
# Build the task list for this notebook: a single task, 'QNLIR', configured
# from the values in the Config cell.
# NOTE(review): `split_prop` is deliberately left as None here even though the
# Config cell defines `split_prop = 0.8` -- confirm this is intended.
tasks = create_tasks(
    ["QNLIR"],
    bert_model,
    max_len=max_len,
    max_datapoints=max_datapoints,
    split_prop=None,
    bert_output_dim=bert_model_output_shape,
    bert_kwargs=dict(),
    dl_kwargs=dict(batch_size=batch_size),
)

Loading QNLIR Dataset











Print some examples from the dataloader. The dataloader returns pairs ((Q, S1), (Q, S2)), where one pair is an entailment for question Q and the other is not.

In [5]:
# Tokenizer for the configured BERT checkpoint; used here only to turn token
# ids back into readable word pieces.
tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=True)

# Take exactly one batch from the test loader. `next(iter(...))` fails loudly
# on an empty loader instead of leaving `x`/`y` undefined (or silently stale
# from a previous run of this cell), which the old `for ...: break` idiom did.
x, y = next(iter(tasks[0].data_loaders['test']))

# Print the first four sequences of the batch's first input element, dropping
# padding tokens so the lines do not end in long runs of whitespace.
for tokens in x[0][:4]:
    word_pieces = tokenizer.convert_ids_to_tokens(tokens.numpy())
    print(' '.join(piece for piece in word_pieces if piece != '[PAD]'))

[CLS] what organization is devoted to jihad against israel ? [SEP] for some decades prior to the first palestine int ##if ##ada in 1987 , the muslim brotherhood in palestine took a " qui ##es ##cent " stance towards israel , focusing on preaching , education and social services , and benefit ##ing from israel ' s " ind ##ul ##gence " to build up a network of mosques and charitable organizations . [SEP]    
[CLS] in what century was the ya ##rrow - sc ##hli ##ck - tweed ##y balancing system used ? [SEP] in the late 19th century , the ya ##rrow - sc ##hli ##ck - tweed ##y balancing ' system ' was used on some marine triple expansion engines . [SEP]                            
[CLS] the largest brand of what store in the uk is located in kingston park ? [SEP] close to newcastle , the largest indoor shopping centre in europe , the metro ##cent ##re , is located in gates ##head . [SEP]                                     
[CLS] what does the ip ##cc rely on for research ? [SEP] in principle

In [6]:
# Display the labels of the test batch fetched in the previous cell.
y

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

Create and train MetalModel for QNLI ranking task.

In [6]:
# Instantiate the multi-task model over the task list, with construction-time
# logging turned off.
model = MetalModel(
    tasks,
    verbose=False,
)

In [7]:
# Train for a single epoch. Checkpoints go under $METALHOME, so that
# environment variable must be set (a KeyError is raised otherwise).
# Checkpointing tracks 'QNLIR/valid/accuracy' in "max" mode -- presumably the
# highest-scoring checkpoint is kept; verify against the trainer.
trainer = MultitaskTrainer()
trainer.train_model(
    model,
    tasks,
    n_epochs=1,
    lr=1e-5,
    progress_bar=True,
    checkpoint_dir=f"{os.environ['METALHOME']}/checkpoints/qnli_single",
    checkpoint_metric_mode="max",
    checkpoint_metric='QNLIR/valid/accuracy',
)

Beginning train loop.
Expecting a total of approximately 83680 examples and 5230 batches per epoch from 1 tasks.


[1.0 epo]: TRAIN:[model/loss=4.78e-01, model/lr=3.82e-09] VALID:[QNLIR/accuracy=7.33e-01, QNLIR/f1=7.50e-01, QNLIR/acc_f1=7.42e-01]

Finished Training
{'QNLIR/test/acc_f1': 0.7238095238095238,
 'QNLIR/test/accuracy': 0.7333333333333333,
 'QNLIR/test/f1': 0.7142857142857143}


In [15]:
# Sanity-check the trained model on the first validation batch only: raw
# outputs, loss, and probabilities for the 'QNLIR' task.
for (X, Y) in tasks[0].data_loaders['valid']:
    outputs = model(X, ['QNLIR'])
    print(outputs)
    # Put the labels on the same device as the model's outputs rather than
    # hard-coding `.cuda()`, so this cell also runs on CPU-only machines.
    Y_on_device = Y.to(outputs['QNLIR'].device)
    print(model.calculate_loss(X, Y_on_device, ['QNLIR']))
    print(model.calculate_probs(X, ['QNLIR']))
    break

{'QNLIR': tensor([[-1.4154],
        [-1.3964],
        [-1.4229],
        [-1.3765],
        [-1.4278],
        [-1.4216],
        [-1.4259],
        [-1.4023],
        [-1.4292],
        [-1.4186],
        [-1.4167],
        [-1.3918],
        [ 0.4138],
        [-1.4115],
        [-1.4272],
        [-1.4230]], device='cuda:0', grad_fn=<MmBackward>)}
{'QNLIR': tensor(0.7219, device='cuda:0', grad_fn=<NegBackward>)}
{'QNLIR': tensor([[0.1954],
        [0.1984],
        [0.1942],
        [0.2016],
        [0.1934],
        [0.1944],
        [0.1937],
        [0.1974],
        [0.1932],
        [0.1949],
        [0.1952],
        [0.1991],
        [0.6020],
        [0.1960],
        [0.1935],
        [0.1942]], device='cuda:0')}


In [21]:
# Score the trained model on the test split, reporting accuracy only.
# NOTE(review): the recorded output below (~0.49) disagrees with the test
# accuracy reported at the end of training (~0.73) -- confirm which is right.
model.score(
    tasks[0],
    split="test",
    metrics=["accuracy"],
    verbose=True,
)

[0.4946000366099213]

In [None]:
# Display the model's repr to inspect its structure.
model