In [1]:
# Reload imported modules (e.g. edits to the metal.* package) automatically
# before each cell executes, so library changes are picked up without a restart.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline; NOTE(review): no visible cell plots anything.
%matplotlib inline

In [2]:
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from metal.end_model import EndModel
from metal.mmtl.modules import BertEncoder
from metal.mmtl.metal_model import MetalModel
from metal.mmtl.dataset import QNLI
from metal.mmtl.task import Task
from metal.mmtl.trainer import MultitaskTrainer

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


### Config

In [3]:
# Pretrained BERT checkpoint and the size of its pooled/hidden output vector,
# which must match the input size of the task head built below.
bert_model = 'bert-base-uncased'
bert_model_output_shape = 768
max_len = 512      # max tokens per example (BERT's positional-embedding limit)
batch_size = 32
lr = 1e-5
n_epochs = 1

# Seed torch's RNG before any DataLoader or module is created: the loaders
# shuffle and the head is randomly initialised, so without this a
# Restart & Run All gives different results each time.
seed = 42
torch.manual_seed(seed)

In [4]:
# One shuffled DataLoader per QNLI split, keyed by split name.
# 'train' is deliberately left out for now; add it to `splits` for a full run.
splits = ['test', 'dev']
dataloaders = {
    split: QNLI(
        split=split,
        bert_model=bert_model,
        max_len=max_len,
    ).get_dataloader(batch_size=batch_size, shuffle=True)
    for split in splits
}

100%|██████████| 5463/5463 [00:07<00:00, 773.54it/s]
100%|██████████| 5463/5463 [00:06<00:00, 792.79it/s]


In [None]:
# BERT encoder feeding a bias-free linear head with 2 outputs
# (presumably one logit per QNLI label: entailment / not_entailment).
bert_encoder = BertEncoder(bert_model)
ranking_head = nn.Linear(in_features=bert_model_output_shape, out_features=2, bias=False)

# Loaders are presumably ordered [train, valid, test] -- TODO confirm against Task.
# Use the real training split whenever it has been loaded; otherwise fall back
# to 'dev' so this smoke test still runs. NOTE: when the fallback is used, the
# model trains and validates on the same data, so validation scores are optimistic.
# The test slot is None (presumably because GLUE test labels are not public).
train_loader = dataloaders.get('train', dataloaders['dev'])
ranking_task = Task(
    "ranking",
    [train_loader, dataloaders['dev'], None],
    bert_encoder,
    ranking_head,
)

In [None]:
# Wrap the single ranking task in a MetalModel and train it with the
# hyperparameters from the config cell.
tasks = [ranking_task]
model = MetalModel(tasks, verbose=False)

trainer = MultitaskTrainer()
trainer.train_model(
    model,
    tasks,
    n_epochs=n_epochs,
    lr=lr,
    progress_bar=True,
)

In [None]:
# Sanity-check the trained model on one 'dev' batch: raw forward output, loss,
# and post-processed output. The task name must match the name given to
# Task(...) above ("ranking"); the previous 'foo_task' is not a registered task.
task_names = ['ranking']
X, Y = next(iter(dataloaders['dev']))

# Inspection only -- skip building an autograd graph.
with torch.no_grad():
    print(model(X, task_names))
    print(model.calculate_loss(X, Y, task_names))
    print(model.calculate_output(X, task_names))