## Working, but ugly :(

In [1]:
%load_ext autoreload
%autoreload 2

In [4]:
import torch
from torch.utils.data import DataLoader, TensorDataset, Dataset
import numpy as np
import pandas as pd
import metal
import os
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

In [5]:
class BertDataset(Dataset):
    """Base dataset holding BERT-ready token ids, segment ids and labels.

    ``tokens``, ``segments`` and ``labels`` start out as ``None`` and are
    expected to be filled by a subclass's ``preprocess_data``; indexing and
    ``len()`` only work once that has happened.

    Args:
        src_path: path to the tab-separated source file.
        tokenizer: tokenizer object kept on the instance (this base class
            stores it but never calls it).
    """

    def __init__(self, src_path, tokenizer):
        super(BertDataset, self).__init__()
        self.src_path = src_path
        self.tokenizer = tokenizer
        self.raw_data = None   # DataFrame populated by load_data()
        self.tokens = None     # populated by preprocess_data()
        self.segments = None   # populated by preprocess_data()
        self.labels = None     # populated by preprocess_data()

    def __getitem__(self, index):
        """Return ``((tokens, segments), label)`` for the example at ``index``."""
        return (self.tokens[index], self.segments[index]), self.labels[index]

    def __len__(self):
        """Return the number of preprocessed examples."""
        return len(self.tokens)

    def load_data(self):
        """Read ``src_path`` into ``self.raw_data`` as a DataFrame.

        The file is parsed as tab-separated with a header row and the first
        column as index; malformed rows are dropped silently. When the file
        has no ``label`` column, a dummy one filled with ``'entailment'`` is
        added so the frame always has the same layout.
        """
        import inspect

        read_kwargs = dict(sep='\t', header=0, index_col=0)
        # `error_bad_lines` / `warn_bad_lines` were replaced by `on_bad_lines`
        # in newer pandas; pick whichever spelling this installation accepts.
        if 'on_bad_lines' in inspect.signature(pd.read_csv).parameters:
            read_kwargs['on_bad_lines'] = 'skip'
        else:
            read_kwargs['error_bad_lines'] = False
            read_kwargs['warn_bad_lines'] = False
        self.raw_data = pd.read_csv(self.src_path, **read_kwargs)

        if 'label' not in self.raw_data.columns:
            # add dummy column to match data input format; size it from the
            # frame itself because self.tokens is still None at this point
            self.raw_data['label'] = ['entailment'] * len(self.raw_data)

    def preprocess_data(self):
        """Populate ``tokens``, ``segments`` and ``labels``; subclasses must override."""
        raise NotImplementedError

In [6]:
import codecs
def load_tsv(data_file,
             sent1_idx=7,
             sent2_idx=8,
             label_idx=9,
             skip_rows=1,
             delimiter='\t',
             label_fn=lambda x: float(x)/5,
             tokenizer=None,
            ):
    """ Loads and tokenizes .tsv dataset into BERT-friendly sentences / segments.

    Args:
        data_file: path to .tsv file
        sent1_idx: tsv index for sentence1
        sent2_idx: tsv index for sentence2
        label_idx: tsv index for label field
        skip_rows: number of rows to skip (i.e. header rows) in .tsv
        delimiter: delimiter between columns (likely '\t') for tab-separated-values
        label_fn: function mapping from raw labels to desired format
        tokenizer: object exposing ``tokenize`` and ``convert_tokens_to_ids``;
            when None, the 'bert-base-uncased' BertTokenizer is loaded
    Returns:
        tokens: list of token-id lists (sentence1 ids followed by sentence2 ids)
        segments: list of 0/1 segment maps (0 for sentence1, 1 for sentence2)
        labels: list of labels produced by label_fn
    Raises:
        ValueError: if a row's label is missing or rejected by label_fn
    """
    if tokenizer is None:
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    tokens, segments, labels = [], [], []
    with codecs.open(data_file, 'r', 'utf-8') as data_fh:

        # skip "header" rows
        for _ in range(skip_rows):
            data_fh.readline()

        # process data rows
        for row_idx, row in enumerate(data_fh):
            row = row.strip().split(delimiter)

            # tokenize and convert each sentence to ids
            sent1_ids = tokenizer.convert_tokens_to_ids(
                tokenizer.tokenize(row[sent1_idx]))
            sent2_ids = tokenizer.convert_tokens_to_ids(
                tokenizer.tokenize(row[sent2_idx]))

            # process labels; fail loudly with the offending row instead of
            # dropping into a debugger and continuing with an unbound label
            try:
                label = label_fn(row[label_idx])
            except Exception as err:
                raise ValueError(
                    'could not parse label in data row {} of {}: {!r}'.format(
                        row_idx, data_file, row)
                ) from err

            # combine sentence pair (sentence1 first, then sentence2) and
            # build the matching segment map so both lists stay aligned
            tokens.append(sent1_ids + sent2_ids)
            segments.append([0] * len(sent1_ids) + [1] * len(sent2_ids))
            labels.append(label)

    return tokens, segments, labels

In [7]:
class STSBDataset(BertDataset):
    """``BertDataset`` variant whose preprocessing is delegated to ``load_tsv``."""

    def __init__(self, src_path, tokenizer):
        # No extra state here; both arguments go straight to the base class.
        super().__init__(src_path, tokenizer)

    def preprocess_data(self):
        """Fill ``tokens``, ``segments`` and ``labels`` from ``self.src_path``.

        ``load_tsv`` is called with its default column indices and is given
        only the path, not ``self.tokenizer``.
        """
        # TODO: fix load_tsv abstraction
        loaded = load_tsv(self.src_path)
        self.tokens, self.segments, self.labels = loaded


In [8]:
def collate_fn(batch, max_len=-1):
    """Pad a list of dataset items into fixed-size tensors for one batch.

    Args:
        batch: list of ``((tokens, segments), label)`` items, where ``tokens``
            and ``segments`` are equal-length sequences of ints.
        max_len: when positive, sequences are truncated to at most this many
            positions; otherwise the longest sequence in the batch sets the
            width.

    Returns:
        ``((idx_matrix, seg_matrix, mask_matrix), label_matrix)`` where the
        first three are int64 tensors of shape ``(batch_size, width)`` and
        ``label_matrix`` is a float32 tensor of shape ``(batch_size, 1)``.
        Padding positions hold 0 in all three matrices; ``mask_matrix`` is 1
        on real token positions.

    Raises:
        ValueError: if ``batch`` is empty.
    """
    batch_size = len(batch)

    # Measure the token sequences themselves: len() of an item is always 2
    # because each item is a ((tokens, segments), label) pair.
    max_sent_len = max(len(tokens) for (tokens, _), _ in batch)
    if 0 < max_len < max_sent_len:
        max_sent_len = max_len

    idx_matrix = np.zeros((batch_size, max_sent_len), dtype=np.int64)
    seg_matrix = np.zeros((batch_size, max_sent_len), dtype=np.int64)
    mask_matrix = np.zeros((batch_size, max_sent_len), dtype=np.int64)
    label_matrix = np.zeros((batch_size, 1), dtype=np.float32)  # TODO: does this work on other tasks?

    for row, ((tokens, segments), label) in enumerate(batch):
        kept = min(len(tokens), max_sent_len)
        idx_matrix[row, :kept] = tokens[:kept]
        seg_matrix[row, :kept] = segments[:kept]
        # 1 = real token, 0 = padding; derived from lengths rather than from
        # the id values so it cannot be confused by the padding id.
        mask_matrix[row, :kept] = 1
        label_matrix[row] = label

    inputs = (
        torch.from_numpy(idx_matrix),
        torch.from_numpy(seg_matrix),
        torch.from_numpy(mask_matrix),
    )
    return inputs, torch.from_numpy(label_matrix)

In [14]:
# Build one DataLoader per STS-B split; the data root comes from $GLUEDATA.
model = 'bert-base-uncased' # also try bert-base-multilingual-cased (recommended)
tokenizer = BertTokenizer.from_pretrained(model, do_lower_case=True)
src_path = os.path.join(os.environ['GLUEDATA'], 'STS-B/{}.tsv')

dataloaders = {}
for split in ('train', 'dev'):
    dataset = STSBDataset(src_path.format(split), tokenizer)
    # TODO: make dataset.load_data() work with dataloader (currently not called)
    dataset.preprocess_data()
    dataloaders[split] = DataLoader(
        dataset,
        batch_size=32,
        collate_fn=collate_fn,
    )

02/06/2019 19:10:11 - INFO - pytorch_pretrained_bert.tokenization -   loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /afs/cs.stanford.edu/u/vschen/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
02/06/2019 19:10:12 - INFO - pytorch_pretrained_bert.tokenization -   loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /afs/cs.stanford.edu/u/vschen/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
02/06/2019 19:10:14 - INFO - pytorch_pretrained_bert.tokenization -   loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /afs/cs.stanford.edu/u/vschen/.pytorch_pretrained_bert/26bc1ad6c0a

In [15]:
import torch.nn as nn
import torch.nn.functional as F
from metal.end_model import EndModel

class BertEncoder(nn.Module):
    """Module wrapping a pretrained 'bert-base-uncased' ``BertModel``."""

    def __init__(self):
        super().__init__()
        self.bert_model = BertModel.from_pretrained('bert-base-uncased')

    def forward(self, data):
        """Run the wrapped model on a ``(tokens, segments, mask)`` triple.

        The model is expected to return a pair; the first element is
        discarded and the second is returned.
        """
        tokens, segments, mask = data
        # TODO: check if we should return all layers or just last hidden representation
        model_outputs = self.bert_model(tokens, segments, mask)
        # NOTE(review): presumably the second element is BertModel's pooled
        # output -- confirm against the pytorch_pretrained_bert docs.
        _, second_output = model_outputs
        return second_output

class STSBHead(EndModel):
    """``EndModel`` trained as a regressor: sigmoid on the output, MSE against Y.

    Args:
        output_dims: layer dimensions forwarded unchanged to ``EndModel``.
        **kwargs: remaining ``EndModel`` options, forwarded unchanged.
    """

    def __init__(self, output_dims, **kwargs):
        super(STSBHead, self).__init__(output_dims, **kwargs)
        self.criteria = nn.MSELoss()

    def _loss(self, data, Y):
        """Return the MSE between ``sigmoid(self.forward(data))`` and ``Y``."""
        output = self.forward(data)
        # torch.sigmoid replaces the deprecated F.sigmoid; it squashes the
        # raw output into (0, 1) before comparison with the targets.
        prediction = torch.sigmoid(output)
        return self.criteria(prediction, Y)

    def _get_loss_fn(self):
        """Return the callable used as the training loss (``self._loss``)."""
        return self._loss

In [16]:
# Plug the BERT encoder in as the input module of the regression head.
# 768 is the encoder width ("hidden_size" in the logged model config); 1 is
# the single regression output.
encoder_module = BertEncoder()
end_model = STSBHead(
    output_dims=[768, 1],
    input_module=encoder_module,
    input_relu=False,
    input_batchnorm=False,
    skip_head=False,
    use_cuda=False,
    seed=123,
    verbose=False,
)

02/06/2019 19:10:17 - INFO - pytorch_pretrained_bert.modeling -   loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz from cache at /afs/cs.stanford.edu/u/vschen/.pytorch_pretrained_bert/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba
02/06/2019 19:10:17 - INFO - pytorch_pretrained_bert.modeling -   extracting archive file /afs/cs.stanford.edu/u/vschen/.pytorch_pretrained_bert/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba to temp dir /tmp/tmp0o0hz3vy
02/06/2019 19:10:21 - INFO - pytorch_pretrained_bert.modeling -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_la

In [None]:
# Fine-tune on the train split, validating on dev after each epoch.
end_model.train_model(
    dataloaders['train'],
    valid_data=dataloaders['dev'],
    lr=0.01,
    l2=0.01,
    # NOTE(review): the DataLoaders were already built with batch_size=32;
    # presumably this value is ignored when DataLoaders are passed -- confirm.
    batch_size=256,
    n_epochs=5,
    checkpoint_metric='valid/loss',
    # Loss is lower-is-better, so checkpoint on the minimum; 'max' would keep
    # the epoch with the worst validation loss.
    checkpoint_metric_mode='min',
    verbose=True,
    progress_bar=True,
)