In [1]:
import os 

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.functional as F
from tqdm import tqdm
from pytorch_pretrained_bert import BertTokenizer, BertModel

from metal.mmtl.glue_datasets import RTEDataset
from metal.end_model import EndModel

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [2]:
model = 'bert-base-uncased' # also try bert-base-multilingual-cased (recommended)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
src_path = os.path.join(os.environ['GLUEDATA'], 'RTE/{}.tsv')
dataset = {}
for split in ['train', 'dev']:
    dataset[split] = RTEDataset(src_path.format(split), tokenizer)
    dataset[split].load_data()
    dataset[split].preprocess_data()

100%|██████████| 2490/2490 [00:03<00:00, 778.08it/s]
100%|██████████| 277/277 [00:00<00:00, 622.80it/s]


In [3]:
# dataset['dev'].get_dataloader

In [4]:
# for batch in dataset['dev'].get_dataloader(batch_size=2):
#     (sent1, sent1_mask, sent2, sent2_mask), label = batch
#     print(batch)
#     break

In [5]:
class LinearSelfAttn(nn.Module):
    def __init__(self, input_size):
        super(LinearSelfAttn, self).__init__()
        self.linear = nn.Linear(input_size, 1)
        self.softmax = nn.Softmax(1)

    def forward(self, x, x_mask):
        scores = self.linear(x).view(x.size(0), x.size(1))
        scores.data.masked_fill_(x_mask.data, -float('inf'))
        alpha = self.softmax(scores)
        return alpha.unsqueeze(1).bmm(x).squeeze(1)

In [6]:
class BilinearSelfAttn(nn.Module):
    def __init__(self, x_size, y_size):
        super(BilinearSelfAttn, self).__init__()
        self.linear = nn.Linear(y_size, x_size)
        self.softmax = nn.Softmax(1)

    def forward(self, x, y, x_mask):
        Wy = self.linear(y)
        xWy = x.bmm(Wy.unsqueeze(2)).squeeze(2)
        xWy.data.masked_fill_(x_mask.data, -float('inf'))
        beta = self.softmax(xWy)
        return beta.unsqueeze(1).bmm(x).squeeze(1)

In [7]:
# class SAN(nn.Module):
#     def __init__(self, emb_size=100, hidden_size=100, num_classes=2, k=5):
#         super(SAN, self).__init__()
#         self.bert_model = BertModel.from_pretrained("bert-base-uncased")
#         self.sent1_attn = LinearSelfAttn(input_size=emb_size)
#         self.sent2_attn = BilinearSelfAttn(emb_size, emb_size)
#         self.final_linear = nn.Linear(emb_size * 4, num_classes)
#         self.rnn = rnn = nn.GRU(emb_size, hidden_size, 1, batch_first=True)
#         self.k = k
#         self.num_classes = num_classes
#         self.softmax = nn.Softmax(1)

#     def forward(self, X):
#         sent1, sent1_mask, sent2, sent2_mask = X
#         batch_size = sent1.size(0)
#         sent1, _ = self.bert_model(
#             sent1, sent1_mask, 1 - sent1_mask, output_all_encoded_layers=False
#         )
#         sent2, _ = self.bert_model(
#             sent2, sent2_mask, 1 - sent2_mask, output_all_encoded_layers=False
#         )
# #         print(sent1)
# #         print(sent1.size())
# #         print(sent2.size())
#         outputs = []
#         # sk (batch * embed_size)
#         sk = self.sent1_attn(sent1, sent1_mask.byte())

#         for i in range(self.k):
#             xk = self.sent2_attn(sent2, sk, sent2_mask.byte())
# #             print(sk.size(), xk.size())
#             _, sk = self.rnn(xk.unsqueeze(1), sk.unsqueeze(0))
#             sk = sk.squeeze(0)
# #             print(sk.size())
#             outputs.append((sk, xk))

#         res = torch.zeros((batch_size, self.num_classes))

#         for i in range(self.k):
#             sk, xk = outputs[i]
#             f = self.softmax(
#                 self.final_linear(torch.cat((sk, xk, torch.abs(sk - xk), sk * xk), 1))
#             )
#             res += f
#         return res / self.k

In [8]:
class SAN(nn.Module):
    def __init__(self, emb_size=100, hidden_size=100, num_classes=2, k=5):
        super(SAN, self).__init__()
        self.bert_model = BertModel.from_pretrained("bert-base-uncased")
        self.sent1_attn = LinearSelfAttn(input_size=emb_size)
        self.sent2_attn = BilinearSelfAttn(emb_size, emb_size)
        self.final_linear = nn.Linear(emb_size * 4, num_classes)
        self.rnn = rnn = nn.GRU(emb_size, hidden_size, 1, batch_first=True)
        self.k = k
        self.num_classes = num_classes
        self.softmax = nn.Softmax(1)
        
        for param in self.bert_model.parameters():
            param.requires_grad = False

    def forward(self, X):
#         print("!")
        sent1, sent1_mask, sent2, sent2_mask = X
        batch_size = sent1.size(0)
        sent1, _ = self.bert_model(
            sent1, sent1_mask, 1 - sent1_mask, output_all_encoded_layers=False
        )
        sent2, _ = self.bert_model(
            sent2, sent2_mask, 1 - sent2_mask, output_all_encoded_layers=False
        )
        res = sent1.new_zeros((batch_size, self.num_classes))
#         res = torch.zeros((batch_size, self.num_classes))
        sk = self.sent1_attn(sent1, sent1_mask.byte())

        for i in range(self.k):
            xk = self.sent2_attn(sent2, sk, sent2_mask.byte())
            _, sk = self.rnn(xk.unsqueeze(1), sk.unsqueeze(0))
            sk = sk.squeeze(0)

            f = self.softmax(
                self.final_linear(torch.cat((sk, xk, torch.abs(sk - xk), sk * xk), 1))
            )
            res += f

        return res / self.k

In [9]:
san = SAN(emb_size = 768, hidden_size = 768)

In [10]:
end_model = EndModel(
    [2], input_module=san, seed=123, device="cuda", skip_head=True, input_relu=False
)


Network architecture:
SAN(
  (bert_model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout(p=0.1)
            )
          )
          (intermediate): BertInte

In [11]:
# end_model.train_model(
#     dataset["train"].get_dataloader(batch_size=10),
#     valid_data=dataset["dev"].get_dataloader(batch_size=10),
# #     lr=5e-5
# )

In [12]:
end_model.train_model(
    dataset["train"].get_dataloader(batch_size=32),
    valid_data=dataset["dev"].get_dataloader(batch_size=32),
#     dataloaders["train"],
#     valid_data=dataloaders["dev"],
    lr=5e-5,
    l2=0,
    n_epochs=3,
#     checkpoint_metric="model/train/loss",
    checkpoint_metric="valid/accuracy",
    log_unit="batches",
    checkpoint_metric_mode="max",
    verbose=True,
    progress_bar=True,
)

Using GPU...
[1 bat (0.00 epo)]: TRAIN:[loss=0.699] VALID:[accuracy=0.534]
Saving model at iteration 1 with best score 0.534
[2 bat (0.00 epo)]: TRAIN:[loss=0.699] VALID:[accuracy=0.534]
[3 bat (0.00 epo)]: TRAIN:[loss=0.688] VALID:[accuracy=0.538]
Saving model at iteration 3 with best score 0.538
[4 bat (0.01 epo)]: TRAIN:[loss=0.699] VALID:[accuracy=0.505]
[5 bat (0.01 epo)]: TRAIN:[loss=0.688] VALID:[accuracy=0.498]
[6 bat (0.01 epo)]: TRAIN:[loss=0.686] VALID:[accuracy=0.505]
[7 bat (0.01 epo)]: TRAIN:[loss=0.702] VALID:[accuracy=0.513]
[8 bat (0.01 epo)]: TRAIN:[loss=0.705] VALID:[accuracy=0.505]
[9 bat (0.01 epo)]: TRAIN:[loss=0.689] VALID:[accuracy=0.520]
[10 bat (0.02 epo)]: TRAIN:[loss=0.700] VALID:[accuracy=0.513]
[11 bat (0.02 epo)]: TRAIN:[loss=0.686] VALID:[accuracy=0.527]
[12 bat (0.02 epo)]: TRAIN:[loss=0.691] VALID:[accuracy=0.534]
[13 bat (0.02 epo)]: TRAIN:[loss=0.689] VALID:[accuracy=0.542]
Saving model at iteration 13 with best score 0.542
[14 bat (0.02 epo)]: TRAIN

[123 bat (0.20 epo)]: TRAIN:[loss=0.679] VALID:[accuracy=0.491]
[124 bat (0.20 epo)]: TRAIN:[loss=0.686] VALID:[accuracy=0.505]
[125 bat (0.20 epo)]: TRAIN:[loss=0.686] VALID:[accuracy=0.513]
[126 bat (0.20 epo)]: TRAIN:[loss=0.665] VALID:[accuracy=0.513]
[127 bat (0.20 epo)]: TRAIN:[loss=0.688] VALID:[accuracy=0.509]
[128 bat (0.21 epo)]: TRAIN:[loss=0.691] VALID:[accuracy=0.516]
[129 bat (0.21 epo)]: TRAIN:[loss=0.688] VALID:[accuracy=0.531]
[130 bat (0.21 epo)]: TRAIN:[loss=0.683] VALID:[accuracy=0.538]
[131 bat (0.21 epo)]: TRAIN:[loss=0.679] VALID:[accuracy=0.534]
[132 bat (0.21 epo)]: TRAIN:[loss=0.699] VALID:[accuracy=0.527]
[133 bat (0.21 epo)]: TRAIN:[loss=0.667] VALID:[accuracy=0.531]
[134 bat (0.22 epo)]: TRAIN:[loss=0.694] VALID:[accuracy=0.513]
[135 bat (0.22 epo)]: TRAIN:[loss=0.685] VALID:[accuracy=0.531]
[136 bat (0.22 epo)]: TRAIN:[loss=0.674] VALID:[accuracy=0.563]
[137 bat (0.22 epo)]: TRAIN:[loss=0.685] VALID:[accuracy=0.570]
[138 bat (0.22 epo)]: TRAIN:[loss=0.697]

[249 bat (0.40 epo)]: TRAIN:[loss=0.702] VALID:[accuracy=0.542]
[250 bat (0.40 epo)]: TRAIN:[loss=0.682] VALID:[accuracy=0.549]
[251 bat (0.40 epo)]: TRAIN:[loss=0.666] VALID:[accuracy=0.563]
[252 bat (0.40 epo)]: TRAIN:[loss=0.683] VALID:[accuracy=0.567]
[253 bat (0.41 epo)]: TRAIN:[loss=0.684] VALID:[accuracy=0.570]
[254 bat (0.41 epo)]: TRAIN:[loss=0.670] VALID:[accuracy=0.574]
[255 bat (0.41 epo)]: TRAIN:[loss=0.715] VALID:[accuracy=0.574]
[256 bat (0.41 epo)]: TRAIN:[loss=0.686] VALID:[accuracy=0.574]
[257 bat (0.41 epo)]: TRAIN:[loss=0.703] VALID:[accuracy=0.570]
[258 bat (0.41 epo)]: TRAIN:[loss=0.685] VALID:[accuracy=0.567]
[259 bat (0.42 epo)]: TRAIN:[loss=0.667] VALID:[accuracy=0.560]
[260 bat (0.42 epo)]: TRAIN:[loss=0.728] VALID:[accuracy=0.545]
[261 bat (0.42 epo)]: TRAIN:[loss=0.682] VALID:[accuracy=0.538]
[262 bat (0.42 epo)]: TRAIN:[loss=0.676] VALID:[accuracy=0.542]
[263 bat (0.42 epo)]: TRAIN:[loss=0.671] VALID:[accuracy=0.545]
[264 bat (0.42 epo)]: TRAIN:[loss=0.705]

[378 bat (0.61 epo)]: TRAIN:[loss=0.668] VALID:[accuracy=0.538]
[379 bat (0.61 epo)]: TRAIN:[loss=0.675] VALID:[accuracy=0.545]
[380 bat (0.61 epo)]: TRAIN:[loss=0.665] VALID:[accuracy=0.534]
[381 bat (0.61 epo)]: TRAIN:[loss=0.690] VALID:[accuracy=0.552]
[382 bat (0.61 epo)]: TRAIN:[loss=0.691] VALID:[accuracy=0.552]
[383 bat (0.62 epo)]: TRAIN:[loss=0.669] VALID:[accuracy=0.556]
[384 bat (0.62 epo)]: TRAIN:[loss=0.672] VALID:[accuracy=0.552]
[385 bat (0.62 epo)]: TRAIN:[loss=0.681] VALID:[accuracy=0.560]
[386 bat (0.62 epo)]: TRAIN:[loss=0.678] VALID:[accuracy=0.560]
[387 bat (0.62 epo)]: TRAIN:[loss=0.669] VALID:[accuracy=0.549]
[388 bat (0.62 epo)]: TRAIN:[loss=0.674] VALID:[accuracy=0.542]
[389 bat (0.62 epo)]: TRAIN:[loss=0.692] VALID:[accuracy=0.538]
[390 bat (0.63 epo)]: TRAIN:[loss=0.701] VALID:[accuracy=0.534]
Restoring best model from iteration 150 with score 0.596
Finished Training
Accuracy: 0.596
        y=1    y=2   
 l=1    31     100   
 l=2    12     134   


In [13]:
# Test end model
end_model.score(dataset["dev"].get_dataloader(batch_size=10), metric=["accuracy", "precision", "recall", "f1"])

Accuracy: 0.596
Precision: 0.721
Recall: 0.237
F1: 0.356
        y=1    y=2   
 l=1    31     100   
 l=2    12     134   


[0.5956678700361011,
 0.7209302325581395,
 0.2366412213740458,
 0.3563218390804598]