In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import metal
import os
from pytorch_pretrained_bert import BertTokenizer, BertModel
from dataset import QQPDataset, RTEDataset, WNLIDataset, MNLIDataset

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [3]:
# GLUE tasks this notebook has a dataset class and a task head for.
all_tasks = ["RTE", "WNLI", "QQP", "MNLI"]

# Task to fine-tune on; change this to any entry of `all_tasks`.
task_name = "MNLI"

In [None]:
# Dataset class and held-out split name for every supported task.
# All tasks evaluate on their 'dev' split except MNLI, which uses 'dev_matched'.
_TASK_DATASETS = {
    "RTE": (RTEDataset, 'dev'),
    "WNLI": (WNLIDataset, 'dev'),
    "QQP": (QQPDataset, 'dev'),
    "MNLI": (MNLIDataset, 'dev_matched'),
}

# Fail loudly on an unknown task instead of hitting a NameError below (or,
# worse, silently reusing loaders left in the kernel by an earlier run).
if task_name not in _TASK_DATASETS:
    raise ValueError(
        "Unsupported task %r; expected one of %s"
        % (task_name, sorted(_TASK_DATASETS))
    )

dataset_cls, test_split = _TASK_DATASETS[task_name]

# The labelled training data is divided into a train and a validation loader.
# NOTE(review): split_prop=0.8 presumably keeps 80% for training - confirm in dataset.py.
train_ds = dataset_cls(split='train', bert_model='bert-base-uncased', max_len=128)
train_dl, dev_dl = train_ds.get_dataloader(split_prop=0.8, batch_size=32)

# The task's held-out split is used as the test set.
test_ds = dataset_cls(split=test_split, bert_model='bert-base-uncased', max_len=128)
test_dl = test_ds.get_dataloader(batch_size=32)

# Loaders keyed by split name, as passed to `Task` further down.
dataloaders = {
    'train': train_dl,
    'valid': dev_dl,
    'test': test_dl
}

 80%|████████  | 315760/392702 [03:42<00:40, 1879.18it/s]

In [None]:
import torch.nn as nn

# Name of the pretrained BERT checkpoint; used as the encoder's default below.
# Note that the training cell later rebinds `model` to the MetalModel instance.
model = 'bert-base-uncased'

# Dropout probability applied to the encoder output when none is given explicitly.
hidden_dropout_prob = 0.1

class BertEncoder(nn.Module):
    """Input module that runs pretrained BERT and returns its pooled output.

    Args:
        bert_model: Name (or path) of the pretrained checkpoint handed to
            ``BertModel.from_pretrained``. Defaults to the checkpoint named by
            the module-level ``model`` string at class-definition time.
        dropout_prob: Dropout probability applied to the pooled output. When
            ``None``, the module-level ``hidden_dropout_prob`` is read at
            construction time.
    """

    def __init__(self, bert_model=model, dropout_prob=None):
        super(BertEncoder, self).__init__()
        if dropout_prob is None:
            dropout_prob = hidden_dropout_prob
        self.bert_model = BertModel.from_pretrained(bert_model)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, data):
        """Encode one batch.

        Args:
            data: A 3-tuple ``(tokens, segments, mask)`` that is forwarded
                positionally to the BERT model.

        Returns:
            The second value returned by the BERT model (its pooled output),
            after dropout.
        """
        tokens, segments, mask = data
        # With output_all_encoded_layers=False the first return value is only the
        # last encoder layer; it is not needed here, just the pooled output.
        _, pooled_output = self.bert_model(tokens, segments, mask, output_all_encoded_layers=False)
        return self.dropout(pooled_output)

In [None]:
class LinearLayer(nn.Module):
    """Task head consisting of a single fully connected layer."""

    def __init__(self, input_size, output_size):
        """Create the layer mapping ``input_size`` features to ``output_size`` outputs."""
        super().__init__()
        self.linear = nn.Linear(in_features=input_size, out_features=output_size)

    def forward(self, x):
        """Apply the linear projection to ``x`` and return the result."""
        projected = self.linear(x)
        return projected

In [None]:
# Input width of every task head (768 is the hidden size of bert-base models).
hidden_bert_size=768

# Number of output classes of each supported task.
_TASK_NUM_CLASSES = {
    "QQP": 2,
    "MNLI": 3,
    "RTE": 2,
    "WNLI": 2,
}

def get_task_head(name):
    """Build the classification head for the task called ``name``.

    Args:
        name: Task name; one of "QQP", "MNLI", "RTE" or "WNLI".

    Returns:
        A ``LinearLayer`` mapping ``hidden_bert_size`` features to the task's
        number of classes.

    Raises:
        ValueError: If ``name`` is not a supported task. Previously the
            function fell through and returned ``None``, which only surfaced
            later as an unrelated-looking error.
    """
    if name not in _TASK_NUM_CLASSES:
        raise ValueError(
            "Unsupported task %r; expected one of %s"
            % (name, sorted(_TASK_NUM_CLASSES))
        )
    return LinearLayer(hidden_bert_size, _TASK_NUM_CLASSES[name])

In [None]:
from metal.mmtl.task import Task
# from BERT_tasks import create_task

# dataloaders = get_dataloaders(task_name)
# Assemble the single task: a fresh BERT encoder followed by the linear head
# that matches the selected task.
task_head = get_task_head(task_name)

task = Task(
    task_name,
    dataloaders,
    BertEncoder(),
    task_head,
)

# The model and trainer below take a list of tasks, even for a single task.
tasks = [task]

In [None]:
from metal.end_model import EndModel
from metal.mmtl.metal_model import MetalModel
from metal.mmtl.trainer import MultitaskTrainer

# Train on the GPU when one is visible, otherwise fall back to the CPU instead
# of crashing on the hard-coded "cuda" device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Wrap the task list in a MetalModel and fine-tune it with the multitask trainer.
model = MetalModel(tasks, verbose=False)
trainer = MultitaskTrainer()
trainer.train_model(
    model,
    tasks,
    n_epochs=5,
    lr=5e-5,
    progress_bar=False,
    # NOTE(review): 0.25 presumably means "every quarter of an epoch" - confirm
    # against the trainer's configuration docs.
    log_every=0.25,
    score_every=0.25,
    checkpoint_best=True,
    #checkpoint_metric=task.name + "/valid/accuracy",
    #checkpoint_metric_mode="max",
    verbose=True,
    device=device,
)