In [1]:
### HEAVILY DRAWING FROM README OF https://github.com/huggingface/pytorch-pretrained-BERT
%load_ext autoreload
%autoreload 2

# pip install pytorch-pretrained-bert
import torch
import torch.nn as nn
import torch.nn.functional as F
from metal.end_model import EndModel
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd
import metal
import os
import numpy as np
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM
from glue_datasets import BERT_Dataset

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


### Initialize filename and tokenizer for BERT vocabulary

In [2]:
SST_root = os.environ['GLUEDATA']+'/SST-2/'

In [3]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

### Create a torch-tensor representation of each sentence, where each token is an index to the BERT vocab

In [4]:
def pre_process(SST):
    tokens, segments, labels = [], [], []
    for ii in range(len(SST['sentence'])):
        sentence = SST['sentence'][ii]
        label = SST['label'][ii]
        tokenized = tokenizer.tokenize(sentence)
        indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized)
        tokens_tensor = torch.tensor([indexed_tokens])
        tokens.append(tokens_tensor)
        segments_tensor = torch.zeros(tokens_tensor.shape, dtype=torch.float32)
        segments.append(segments_tensor)
        labels.append(label+1)
    return tokens, segments, labels

In [5]:
class SST2_Dataset(BERT_Dataset):
    def __init__(self, src_path, tokenizer):
        super(SST2_Dataset, self).__init__(src_path, tokenizer)
        
    def load_data(self):
        self.raw_data = pd.read_csv(self.src_path, delimiter = "\t", header = 0)

    def preprocess_data(self):
        self.tokens, self.segments, self.labels = pre_process(self.raw_data)

In [10]:
def collate_fn(batch, max_len=-1):
    batch_size = len(batch)
    max_sent_len = int(np.max([x[0][0].shape[1] for x in batch]))
    if max_len > 0 and max_len < max_sent_len:
        max_sent_len = max_len

    idx_matrix = np.zeros((batch_size, max_sent_len), dtype=np.int)
    seg_matrix = np.zeros((batch_size, max_sent_len), dtype=np.float32)
    label_matrix = np.zeros((batch_size, 1)) # TODO: does this work on other tasks?

    for idx1 in np.arange(len(batch)):
        (tokens, segments), labels = batch[idx1]
        for idx2 in np.arange(len(tokens)):
            if idx2 >= max_sent_len:
                break
            import pdb; pdb.set_trace()
            idx_matrix[idx1, idx2] = tokens[idx2]
            seg_matrix[idx1, idx2] = segments[idx2]

        label_matrix[idx1] = labels

    idx_matrix = torch.LongTensor(idx_matrix)
    seg_matrix = torch.LongTensor(seg_matrix)
    mask_matrix = torch.eq(idx_matrix.data, -1).long()
    label_matrix = torch.FloatTensor(label_matrix)
    
    return (idx_matrix, seg_matrix, mask_matrix), label_matrix

In [17]:
dataset = {}
dataloaders = {}
for split in ['train', 'dev']:
    dataset[split] = SST2_Dataset(SST_root+split+'.tsv', tokenizer)
    dataset[split].load_data()
    dataset[split].preprocess_data()
    dataloaders[split] = DataLoader(dataset[split], batch_size=4)

In [14]:
class BertEncoder(nn.Module):
    def __init__(self):
        super(BertEncoder, self).__init__()
        self.bert_model = BertModel.from_pretrained('bert-base-uncased')
        
    def forward(self, data):
        tokens, segments, mask = data
        # TODO: check if we should return all layers or just last hidden representation 
        _, hidden_layer = self.bert_model(tokens, segments, mask)
        return hidden_layer

encoder_module = BertEncoder()
end_model = EndModel(
    [768, 1],
    input_module=encoder_module,
    seed=123,
    use_cuda=False,
    skip_head=False,
    input_relu=False,
    input_batchnorm=False,
    verbose=False
)

In [18]:
end_model.train_model(dataloaders['train'], valid_data=dataloaders['dev'],
                      lr=0.01, l2=0.01, 
                      batch_size=4, 
                      n_epochs=5, checkpoint_metric='valid/loss',
                      checkpoint_metric_mode='max',
                      verbose=True, progress_bar=True
                    )




  0%|          | 0/2105 [00:00<?, ?it/s][A[A[A

Could not find kwarg "valid_data" in destination dict.
Could not find kwarg "checkpoint_metric" in destination dict.
Could not find kwarg "checkpoint_metric_mode" in destination dict.
Could not find kwarg "progress_bar" in destination dict.


RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 8 and 9 in dimension 2 at /opt/conda/conda-bld/pytorch_1535491974311/work/aten/src/TH/generic/THTensorMath.cpp:3616