In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import metal
import os
from pytorch_pretrained_bert import BertTokenizer, BertModel
from dataset import BERTDataset

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [3]:
# Name of the pretrained BERT checkpoint.
# NOTE(review): nothing in the visible cells reads this variable (BertEncoder
# names the checkpoint itself), and the training cell later rebinds `model`
# to a MetalModel instance -- consider renaming to avoid the clash.
model = 'bert-base-uncased'

def get_dataloaders(name, batch_size=32, max_len=128):
    """Build the dataloaders for one GLUE task.

    The TSV files are read from ``$GLUEDATA/<name>/<split>.tsv``; the header
    row of each file is skipped.

    Args:
        name: task name, one of "QQP", "MNLI", "RTE", "WNLI".
        batch_size: batch size passed to ``BERTDataset.get_dataloader``.
        max_len: ``max_len`` passed to ``BERTDataset``.

    Returns:
        dict mapping split name to the dataloader built for it. Keys are
        "train" and "dev", except for MNLI where they are "train" and
        "dev_matched".

    Raises:
        ValueError: if ``name`` is not a supported task (previously this
            surfaced as an UnboundLocalError at the ``return``).
        KeyError: if the ``GLUEDATA`` environment variable is not set.
    """
    mnli_labels = ["contradiction", "entailment", "neutral"]

    # Per-task TSV layout. "splits" pairs each split name with the column
    # that holds its label; label_fn maps the raw label string to a
    # 1-indexed class id.
    task_specs = {
        "QQP": {
            "splits": (("train", 5), ("dev", 5)),
            "sent1_idx": 3,
            "sent2_idx": 4,
            "label_fn": lambda label: 1 if label == '0' else 2,
        },
        "MNLI": {
            # In the GLUE files, train.tsv has its gold label in column 11,
            # while dev_matched.tsv carries five annotator labels first and
            # the gold label in column 15 (column 11 there is an individual
            # annotator label). Verify against your GLUE download.
            "splits": (("train", 11), ("dev_matched", 15)),
            "sent1_idx": 8,
            "sent2_idx": 9,
            "label_fn": lambda label: mnli_labels.index(label) + 1,
        },
        "RTE": {
            "splits": (("train", 3), ("dev", 3)),
            "sent1_idx": 1,
            "sent2_idx": 2,
            "label_fn": lambda label: 1 if label == 'entailment' else 2,
        },
        "WNLI": {
            # GLUE WNLI columns are index, sentence1, sentence2, label, so the
            # label is column 3; column 0 (used previously) is the row index.
            "splits": (("train", 3), ("dev", 3)),
            "sent1_idx": 1,
            "sent2_idx": 2,
            "label_fn": lambda label: 1 if label == '0' else 2,
        },
    }

    # Validate before touching the environment or the filesystem.
    if name not in task_specs:
        raise ValueError(
            "Unknown task {!r}; expected one of {}".format(name, sorted(task_specs))
        )
    spec = task_specs[name]

    src_path = os.path.join(os.environ['GLUEDATA'], name + '/{}.tsv')

    # Only labeled splits are loaded here. The original code reserved
    # label_idx=-1 for any other split (presumably unlabeled test data).
    dataloaders = {}
    for split, label_idx in spec["splits"]:
        dataset = BERTDataset(
            src_path.format(split),
            sent1_idx=spec["sent1_idx"],
            sent2_idx=spec["sent2_idx"],
            label_idx=label_idx,
            skip_rows=1,
            label_fn=spec["label_fn"],
            max_len=max_len,
        )
        dataloaders[split] = dataset.get_dataloader(batch_size=batch_size)
    return dataloaders

In [4]:
import torch.nn as nn

# Dropout probability read by BertEncoder when it builds its nn.Dropout layer.
hidden_dropout_prob = 0.1

class BertEncoder(nn.Module):
    """Pretrained BERT followed by dropout, used as the shared input module.

    Args:
        bert_model_name: name of the pretrained checkpoint handed to
            ``BertModel.from_pretrained``. Defaults to 'bert-base-uncased'.
        dropout_prob: dropout probability applied to the BERT output. When
            None (the default), the module-level ``hidden_dropout_prob`` is
            used.
    """

    def __init__(self, bert_model_name='bert-base-uncased', dropout_prob=None):
        super(BertEncoder, self).__init__()
        if dropout_prob is None:
            # Resolved at construction time so edits to the notebook-level
            # constant are picked up without redefining the class.
            dropout_prob = hidden_dropout_prob
        self.bert_model = BertModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, data):
        """Encode a batch.

        Args:
            data: a 3-tuple ``(tokens, segments, mask)`` forwarded
                positionally to the BERT model.

        Returns:
            The second element of the BERT model's output with dropout
            applied. Per the pytorch_pretrained_bert documentation this is
            the pooled output of the first token.
        """
        tokens, segments, mask = data
        # The first output (the per-token encoded layer) is not needed here.
        _, pooled_output = self.bert_model(
            tokens, segments, mask, output_all_encoded_layers=False
        )
        return self.dropout(pooled_output)

In [6]:
class LinearLayer(nn.Module):
    """Thin nn.Module wrapper around a single fully connected layer.

    Args:
        input_size: number of input features.
        output_size: number of output features.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        # The attribute name `linear` determines the state_dict keys
        # ("linear.weight", "linear.bias").
        self.linear = nn.Linear(in_features=input_size, out_features=output_size)

    def forward(self, x):
        """Apply the linear projection to ``x`` and return the result."""
        projected = self.linear(x)
        return projected

In [7]:
# Width of the vector each task head receives from the encoder
# (768 is the hidden size of the bert-base checkpoints).
hidden_bert_size = 768

def get_task_head(name):
    """Return a freshly initialised classification head for a GLUE task.

    Args:
        name: task name, one of "QQP", "MNLI", "RTE", "WNLI".

    Returns:
        A ``LinearLayer`` mapping ``hidden_bert_size`` features to the number
        of classes of the task (3 for MNLI, 2 for the others).

    Raises:
        ValueError: if ``name`` is not a supported task. Previously the
            function fell through and returned None.
    """
    num_classes = {"QQP": 2, "MNLI": 3, "RTE": 2, "WNLI": 2}
    if name not in num_classes:
        raise ValueError(
            "Unknown task {!r}; expected one of {}".format(name, sorted(num_classes))
        )
    return LinearLayer(hidden_bert_size, num_classes[name])

In [8]:
# Build the single MeTaL Task that the training cell consumes.
# NOTE(review): the saved output of this cell is
#   TypeError: __init__() missing 1 required positional argument: 'dataset_split'
# The traceback is not included, so it is unclear which constructor
# (BERTDataset inside get_dataloaders, or Task) needs the extra argument --
# check the current signatures before re-running.
from metal.mmtl.task import Task
# from BERT_tasks import create_task

all_tasks = ["RTE", "WNLI", "QQP", "MNLI"]

# Pick the task to train on by position in all_tasks (index 0 is "RTE").
task_name = all_tasks[0]
dataloaders = get_dataloaders(task_name)
task_head = get_task_head(task_name)

# The list passed to Task holds the train loader, the dev loader and None
# (presumably the slot for a test loader -- confirm against Task's signature).
# NOTE(review): get_dataloaders keys the MNLI dev split as "dev_matched", so
# dataloaders["dev"] raises KeyError when task_name is "MNLI".
task = Task(task_name, [dataloaders["train"], dataloaders["dev"], None], BertEncoder(), task_head)
tasks = [task]

TypeError: __init__() missing 1 required positional argument: 'dataset_split'

In [None]:
# Train a MetalModel on the tasks built in the previous cell.
# NOTE(review): EndModel is imported but not used in the visible cells.
from metal.end_model import EndModel
from metal.mmtl.metal_model import MetalModel
from metal.mmtl.trainer import MultitaskTrainer

# NOTE(review): this rebinds `model`, which an earlier cell set to the
# checkpoint name string 'bert-base-uncased'.
model = MetalModel(tasks, verbose=True)
trainer = MultitaskTrainer()
# The "<task_name>/accuracy" metric is both logged and used as the checkpoint
# metric with mode "max" (presumably keeping the best-scoring checkpoint --
# confirm against MultitaskTrainer's documentation).
trainer.train_model(
    model, 
    tasks, 
    n_epochs=5, 
    lr=5e-5,
    progress_bar=True,
    log_valid_metrics=[task_name + "/accuracy"],
    checkpoint_metric=task_name + "/accuracy",
    checkpoint_metric_mode="max",
    verbose=True,
)