In [112]:
import os 
from tqdm import tqdm
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.functional as F
import torch.utils.data as data

from metal import EndModel
from preprocess import NlpProcess

from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

### Goal: run standard MeTaL (single-task) EndModel for QNLI ranking task. 

### Preprocess data

In [166]:
class BertDataset(data.Dataset):
    def __init__(self, src_path, tokenizer):
        super(BertDataset, self).__init__()
        self.src_path = src_path
        self.tokenizer = tokenizer
        self.raw_data = None
        self.tokens = None
        self.segments = None
        self.labels = None
        
    def __getitem__(self, index):
        return (self.tokens[index], self.segments[index]), self.labels[index]
    
    def __len__(self):
        return self.raw_data.shape[0]

    def load_data(self):
        self.raw_data = pd.read_csv(
            self.src_path, sep='\t', header=0,
            index_col=0, error_bad_lines=False, warn_bad_lines=False
        )
        if 'label' not in self.raw_data.columns:
            # add dummy column to match data input format
            self.raw_data['label'] = ['entailment'] * self.__len__()
    
    def preprocess_data(self):
        raise NotImplementedError
            
class QNLIDataset(BertDataset):
    def __init__(self, src_path, tokenizer):
        super(QNLIDataset, self).__init__(src_path, tokenizer)

    def preprocess_data(self):
        tokens_indices = []
        segments = []
        #for i, row in tqdm(data.iterrows()):
        for i, row in tqdm(list(self.raw_data.iterrows())[:1]):
            question = row.question
            sentence = row.sentence
            tokenized_question = self.tokenizer.tokenize(question)
            tokenized_sentence = self.tokenizer.tokenize(sentence)
            tokenized_text = tokenized_question + ['[SEP]'] + tokenized_sentence
            tokens_indices.append(self.tokenizer.convert_tokens_to_ids(tokenized_text))
            segments.append(([0] * (len(tokenized_question)+1)) + ([1] * len(tokenized_sentence)))
            y_train[i, 0] = 1*(self.raw_data.label[i]=='entailment')
            y_train[i, 1] = 1*(self.raw_data.label[i]=='not_entailment')
        self.labels = labels
        self.tokens = tokens_indices
        self.segments = segments

In [185]:
def collate_fn(batch, max_len=-1):
    batch_size = len(batch)
    max_sent_len = int(np.max([len(x) for x in batch]))
    if max_len > 0 and max_len < max_sent_len:
        max_sent_len = max_len
    if type == "float":
        idx_matrix = np.zeros((batch_size, max_sent_len), dtype=np.float32)
        seg_matrix = np.zeros((batch_size, max_sent_len), dtype=np.float32)
        
    else:
        idx_matrix = np.zeros((batch_size, max_sent_len), dtype=np.int)
        seg_matrix = np.zeros((batch_size, max_sent_len), dtype=np.float32)

    for idx1 in np.arange(len(batch)):
        (tokens, segments), labels = batch[idx1]
        for idx2 in np.arange(len(tokens)):
            if idx2 >= max_sent_len:
                break
            idx_matrix[idx1, idx2] = tokens[idx2]
            seg_matrix[idx1, idx2] = segments[idx2]
            
    idx_matrix = torch.LongTensor(idx_matrix)
    seg_matrix = torch.LongTensor(seg_matrix)
    mask_matrix = torch.LongTensor(torch.eq(idx_matrix.data, -1))
    return idx_matrix, seg_matrix, mask_matrix

In [186]:
model = 'bert-base-uncased' # also try bert-base-multilingual-cased (recommended)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
src_path = os.path.join(os.environ['METALHOME'], 'data/QNLI/{}.tsv')
dataset = {}
dataloaders = {}
for split in ['train', 'test', 'dev']:
    dataset[split] = QNLIDataset(src_path.format(split), tokenizer)
    dataset[split].load_data()
    dataset[split].preprocess_data()
    dataloaders[split] = data.DataLoader(dataset[split], collate_fn=collate_fn)

100%|██████████| 1/1 [00:00<00:00, 13.70it/s]
100%|██████████| 1/1 [00:00<00:00, 188.14it/s]
100%|██████████| 1/1 [00:00<00:00, 458.49it/s]


In [187]:
dataset['dev'].raw_data.head()

Unnamed: 0_level_0,question,sentence,label
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,What came into force after the new constitutio...,"As of that day, the new constitution heralding...",entailment
1,What is the first major city in the stream of ...,The most important tributaries in this area ar...,not_entailment
2,What is the minimum required if you want to te...,In most provinces a second Bachelor's Degree s...,not_entailment
3,How was Temüjin kept imprisoned by the Tayichi...,The Tayichi'ud enslaved Temüjin (reportedly wi...,entailment
4,"What did Herr Gott, dich loben wir become know...","He paraphrased the Te Deum as ""Herr Gott, dich...",not_entailment


## Metal Model 

In [188]:
class BertEncoder(nn.Module):
    def __init__(self):
        super(BertEncoder, self).__init__()
        self.bert_model = BertModel.from_pretrained('bert-base-uncased')
        
    def forward(self, tokens, segments):
        # TODO: check if we should return all layers or just last hidden representation 
        _, hidden_layer = self.bert_model(tokens_tensors, segments_tensors)
        return hidden_layer

In [None]:
encoder_module = BertEncoder()
end_model = EndModel(
    [768, 2],
    input_module=encoder_module,
    seed=123,
    use_cuda=False,
    skip_head=False,
    input_relu=False,
    input_batchnorm=False,
    verbose=False
)

In [None]:
end_model.train_model(dataloaders['train'],
                      lr=0.01, l2=0.01, 
                      batch_size=256, 
                      n_epochs=5, checkpoint_metric='accuracy',
                      checkpoint_metric_mode='max')

In [None]:
# Test end model
end_model.score(test_loader, metric=["accuracy", "precision", "recall", "f1"])