In [1]:
# Enable IPython's autoreload extension. Mode 2 reloads imported modules before
# each cell executes, so edits to modules such as `dataset` are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch.utils.data import DataLoader, TensorDataset, Dataset
import numpy as np
import pandas as pd
import metal
import os
from pytorch_pretrained_bert import BertTokenizer, BertModel

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [3]:
from dataset import BERTDataset

In [4]:
# Pretrained BERT identifier; also try bert-base-multilingual-cased (recommended).
# NOTE(review): `model` is not passed to BERTDataset in this cell -- confirm
# where (or whether) it is consumed.
model = 'bert-base-uncased'

# Path template for the RTE split files under $GLUEDATA; '{}' receives the split name.
src_path = os.path.join(os.environ['GLUEDATA'], 'RTE/{}.tsv')

# Build one dataloader per split, keyed by split name.
dataloaders = {}
for split in ['train', 'dev']:
    # 'train' and 'dev' take their label from column 3; any other split would get -1.
    if split in ['train', 'dev']:
        label_idx = 3
    else:
        label_idx = -1

    dataset = BERTDataset(
        src_path.format(split),
        sent1_idx=1,
        sent2_idx=2,
        label_idx=label_idx,
        skip_rows=1,
        # 'entailment' maps to class 1; every other label string maps to class 2.
        label_fn=lambda label: 2 if label != 'entailment' else 1,
        max_len=128,
    )
    dataloaders[split] = dataset.get_dataloader(batch_size=32)

100%|██████████| 2490/2490 [00:02<00:00, 1011.58it/s]
100%|██████████| 277/277 [00:00<00:00, 797.56it/s]


In [5]:
import torch.nn as nn
from metal.end_model import EndModel

# Default dropout probability applied to the encoder output.
hidden_dropout_prob = 0.1


class BertEncoder(nn.Module):
    """Input module that wraps a pretrained BERT model and returns its pooled output.

    Args:
        bert_model: Name (or path) handed to ``BertModel.from_pretrained``.
            Defaults to ``'bert-base-uncased'``.
        dropout_prob: Dropout probability applied to the pooled output. When
            ``None``, the module-level ``hidden_dropout_prob`` is read at
            construction time.
        freeze_bert: If True, set ``requires_grad = False`` on every BERT
            parameter so that only downstream layers are trained.
    """

    def __init__(self, bert_model='bert-base-uncased', dropout_prob=None, freeze_bert=False):
        super(BertEncoder, self).__init__()
        if dropout_prob is None:
            # Look the default up now (not at class-definition time) so that
            # changing the module-level constant before instantiation still works.
            dropout_prob = hidden_dropout_prob
        self.bert_model = BertModel.from_pretrained(bert_model)
        self.dropout = nn.Dropout(dropout_prob)
        if freeze_bert:
            for param in self.bert_model.parameters():
                param.requires_grad = False

    def forward(self, data):
        """Encode a batch and return the dropout-regularised pooled output.

        Args:
            data: A 3-tuple ``(tokens, segments, mask)`` forwarded positionally
                to the wrapped BERT model.

        Returns:
            The second value returned by the BERT model, after dropout.
        """
        tokens, segments, mask = data
        # pytorch_pretrained_bert's BertModel returns (encoded_layers, pooled_output).
        # With output_all_encoded_layers=False only the last encoder layer is
        # materialised, and it is discarded here: the sentence-level pooled
        # output is what gets passed on.
        _, pooled_output = self.bert_model(
            tokens, segments, mask, output_all_encoded_layers=False
        )
        return self.dropout(pooled_output)

In [6]:
# Instantiate the BERT input module, then wrap it in a MeTaL EndModel whose
# layer dimensions are [768, 2].
encoder_module = BertEncoder()
end_model = EndModel(
    [768, 2],
    # The encoder is used as-is: no extra ReLU or batch norm on the input module.
    input_module=encoder_module,
    input_relu=False,
    input_batchnorm=False,
    skip_head=False,
    # Runtime settings.
    device=torch.device("cuda"),
    seed=123,
    verbose=False,
)

In [7]:
# Train on the 'train' split and validate on 'dev'.
end_model.train_model(
    dataloaders["train"],
    valid_data=dataloaders["dev"],
    # Optimisation settings.
    n_epochs=5,
    lr=5e-5,
    l2=0,
    # Checkpoint on validation accuracy, higher is better.
    checkpoint_metric="valid/accuracy",
    checkpoint_metric_mode="max",
    # Logging is counted in batches rather than epochs.
    log_unit="batches",
    progress_bar=True,
    verbose=True,
)

Using GPU...
[1 bat (0.00 epo)]: TRAIN:[loss=0.682] VALID:[accuracy=0.527]
Saving model at iteration 1 with best score 0.527
[2 bat (0.00 epo)]: TRAIN:[loss=0.790] VALID:[accuracy=0.527]
[3 bat (0.00 epo)]: TRAIN:[loss=0.662] VALID:[accuracy=0.527]
[4 bat (0.00 epo)]: TRAIN:[loss=0.736] VALID:[accuracy=0.491]
[5 bat (0.01 epo)]: TRAIN:[loss=0.735] VALID:[accuracy=0.473]
[6 bat (0.01 epo)]: TRAIN:[loss=0.728] VALID:[accuracy=0.473]
[7 bat (0.01 epo)]: TRAIN:[loss=0.665] VALID:[accuracy=0.473]
[8 bat (0.01 epo)]: TRAIN:[loss=0.710] VALID:[accuracy=0.473]
[9 bat (0.01 epo)]: TRAIN:[loss=0.644] VALID:[accuracy=0.473]
[10 bat (0.01 epo)]: TRAIN:[loss=0.704] VALID:[accuracy=0.473]
[11 bat (0.01 epo)]: TRAIN:[loss=0.744] VALID:[accuracy=0.473]
[12 bat (0.01 epo)]: TRAIN:[loss=0.733] VALID:[accuracy=0.473]
[13 bat (0.02 epo)]: TRAIN:[loss=0.681] VALID:[accuracy=0.487]
[14 bat (0.02 epo)]: TRAIN:[loss=0.714] VALID:[accuracy=0.509]
[15 bat (0.02 epo)]: TRAIN:[loss=0.708] VALID:[accuracy=0.560]
S

[117 bat (0.14 epo)]: TRAIN:[loss=0.435] VALID:[accuracy=0.686]
[118 bat (0.14 epo)]: TRAIN:[loss=0.393] VALID:[accuracy=0.668]
[119 bat (0.14 epo)]: TRAIN:[loss=0.391] VALID:[accuracy=0.668]
[120 bat (0.14 epo)]: TRAIN:[loss=0.299] VALID:[accuracy=0.646]
[121 bat (0.15 epo)]: TRAIN:[loss=0.601] VALID:[accuracy=0.635]
[122 bat (0.15 epo)]: TRAIN:[loss=0.380] VALID:[accuracy=0.635]
[123 bat (0.15 epo)]: TRAIN:[loss=0.569] VALID:[accuracy=0.657]
[124 bat (0.15 epo)]: TRAIN:[loss=0.480] VALID:[accuracy=0.664]
[125 bat (0.15 epo)]: TRAIN:[loss=0.398] VALID:[accuracy=0.653]
[126 bat (0.15 epo)]: TRAIN:[loss=0.541] VALID:[accuracy=0.657]
[127 bat (0.15 epo)]: TRAIN:[loss=0.555] VALID:[accuracy=0.697]
[128 bat (0.15 epo)]: TRAIN:[loss=0.398] VALID:[accuracy=0.653]
[129 bat (0.16 epo)]: TRAIN:[loss=0.291] VALID:[accuracy=0.603]
[130 bat (0.16 epo)]: TRAIN:[loss=0.456] VALID:[accuracy=0.606]
[131 bat (0.16 epo)]: TRAIN:[loss=0.489] VALID:[accuracy=0.621]
[132 bat (0.16 epo)]: TRAIN:[loss=0.427]

[246 bat (0.30 epo)]: TRAIN:[loss=0.167] VALID:[accuracy=0.661]
[247 bat (0.30 epo)]: TRAIN:[loss=0.131] VALID:[accuracy=0.668]
[248 bat (0.30 epo)]: TRAIN:[loss=0.114] VALID:[accuracy=0.675]
[249 bat (0.30 epo)]: TRAIN:[loss=0.104] VALID:[accuracy=0.664]
[250 bat (0.30 epo)]: TRAIN:[loss=0.323] VALID:[accuracy=0.650]
[251 bat (0.30 epo)]: TRAIN:[loss=0.164] VALID:[accuracy=0.650]
[252 bat (0.30 epo)]: TRAIN:[loss=0.138] VALID:[accuracy=0.653]
[253 bat (0.30 epo)]: TRAIN:[loss=0.204] VALID:[accuracy=0.653]
[254 bat (0.31 epo)]: TRAIN:[loss=0.248] VALID:[accuracy=0.653]
[255 bat (0.31 epo)]: TRAIN:[loss=0.136] VALID:[accuracy=0.653]
[256 bat (0.31 epo)]: TRAIN:[loss=0.129] VALID:[accuracy=0.657]
[257 bat (0.31 epo)]: TRAIN:[loss=0.127] VALID:[accuracy=0.657]
[258 bat (0.31 epo)]: TRAIN:[loss=0.161] VALID:[accuracy=0.653]
[259 bat (0.31 epo)]: TRAIN:[loss=0.097] VALID:[accuracy=0.653]
[260 bat (0.31 epo)]: TRAIN:[loss=0.128] VALID:[accuracy=0.653]
[261 bat (0.31 epo)]: TRAIN:[loss=0.204]

[375 bat (0.45 epo)]: TRAIN:[loss=0.106] VALID:[accuracy=0.679]
[376 bat (0.45 epo)]: TRAIN:[loss=0.121] VALID:[accuracy=0.682]
[377 bat (0.45 epo)]: TRAIN:[loss=0.196] VALID:[accuracy=0.690]
[378 bat (0.46 epo)]: TRAIN:[loss=0.070] VALID:[accuracy=0.697]
[379 bat (0.46 epo)]: TRAIN:[loss=0.114] VALID:[accuracy=0.693]
[380 bat (0.46 epo)]: TRAIN:[loss=0.208] VALID:[accuracy=0.697]
[381 bat (0.46 epo)]: TRAIN:[loss=0.092] VALID:[accuracy=0.693]
[382 bat (0.46 epo)]: TRAIN:[loss=0.063] VALID:[accuracy=0.675]
[383 bat (0.46 epo)]: TRAIN:[loss=0.077] VALID:[accuracy=0.661]
[384 bat (0.46 epo)]: TRAIN:[loss=0.106] VALID:[accuracy=0.664]
[385 bat (0.46 epo)]: TRAIN:[loss=0.280] VALID:[accuracy=0.671]
[386 bat (0.47 epo)]: TRAIN:[loss=0.251] VALID:[accuracy=0.679]
[387 bat (0.47 epo)]: TRAIN:[loss=0.294] VALID:[accuracy=0.704]
[388 bat (0.47 epo)]: TRAIN:[loss=0.148] VALID:[accuracy=0.697]
[389 bat (0.47 epo)]: TRAIN:[loss=0.071] VALID:[accuracy=0.668]
[390 bat (0.47 epo)]: TRAIN:[loss=0.023]

In [8]:
# Score the end model on the dev split: accuracy, precision, recall and F1.
# NOTE(review): the dev split is also the validation data used for checkpoint
# selection during training, so these figures are not a held-out estimate.
end_model.score(
    dataloaders["dev"],
    metric=["accuracy", "precision", "recall", "f1"],
)

Accuracy: 0.711
Precision: 0.692
Recall: 0.815
F1: 0.748
        y=1    y=2   
 l=1    119    27    
 l=2    53     78    


[0.7111913357400722, 0.6918604651162791, 0.815068493150685, 0.7484276729559748]