In [None]:
!pip install transformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/f9/54/5ca07ec9569d2f232f3166de5457b63943882f7950ddfcc887732fc7fb23/transformers-4.3.3-py3-none-any.whl (1.9MB)
[K     |████████████████████████████████| 1.9MB 10.8MB/s 
Collecting tokenizers<0.11,>=0.10.1
[?25l  Downloading https://files.pythonhosted.org/packages/71/23/2ddc317b2121117bf34dd00f5b0de194158f2a44ee2bf5e47c7166878a97/tokenizers-0.10.1-cp37-cp37m-manylinux2010_x86_64.whl (3.2MB)
[K     |████████████████████████████████| 3.2MB 35.7MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
[K     |████████████████████████████████| 890kB 34.0MB/s 
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
  Created wheel for sacremoses: filename=sacremoses-0.0.43-cp37-none-any.whl size=893262 sha256=328ee4530f

In [None]:
from google.colab import drive

# Make Google Drive visible inside the Colab VM.
drive.mount('/content/drive')

# Drive folder that holds the CSV datasets and receives the saved checkpoints.
# The trailing slash matters: file names are appended by string concatenation.
currentDirectory = '/content/drive/My Drive/FSem/'

Mounted at /content/drive


In [None]:
import csv
import torch
from torch.nn.functional import softmax
from transformers import BertForNextSentencePrediction, BertTokenizer
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
from tqdm import tqdm
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup


# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Checkpoint name of the underlying pretrained LM (also used for the tokenizer).
BASE_MODEL = 'bert-large-uncased-whole-word-masking'

# Default hyper-parameters picked up by train().
BATCH_SIZE = 12      # samples per optimisation step
WARMUP_EPOCHS = 1    # epochs of linear learning-rate warm-up
TRAIN_EPOCHS = 10    # total number of epochs
LAST_EPOCH = -1      # -1 means "start from scratch"

In [None]:
class RocStories(torch.utils.data.Dataset):
    """ROCStories rows turned into next-sentence-prediction pairs.

    Every CSV row yields two samples that share the same story context
    (columns 2 up to, but excluding, the last one, joined by spaces):

    * ``[story, true ending]``  with label ``0``
    * ``[story, wrong ending]`` with label ``1``

    The wrong ending is the true ending of another row, obtained by shuffling
    all endings.  A plain shuffle can leave an ending at its own index, which
    would add the very same pair once with label 0 and once with label 1;
    such collisions are repaired after shuffling.

    Args:
        seed: Optional seed for the shuffle.  ``None`` (default) uses the
            global ``random`` module state, as the unseeded shuffle did.

    NOTE(review): no header row is skipped.  If ``roc_stories.csv`` starts
    with a header line, it is turned into a (bogus) sample -- confirm against
    the actual file.
    """

    def __init__(self, seed=None):
        # Function-scope import: the notebook's import cell does not import it.
        import random

        # newline='' is what the csv module documents for files handed to it.
        with open(currentDirectory + 'roc_stories.csv',
                  'r', encoding='utf-8', newline='') as d:
            reader = csv.reader(d, quotechar='"', delimiter=',',
                                quoting=csv.QUOTE_ALL, skipinitialspace=True)
            # csv.reader returns [] for blank lines; indexing them would fail.
            rows = [line for line in reader if line]

        stories = [" ".join(row[2:-1]) for row in rows]
        endings = [row[-1] for row in rows]
        assert len(stories) == len(endings)

        rng = random if seed is None else random.Random(seed)
        wrong_endings = self._mismatched_shuffle(endings, rng)

        self.data = []
        self.labels = []
        for story, true_ending, wrong_ending in zip(stories, endings,
                                                    wrong_endings):
            # True ending
            self.data.append([story, true_ending])
            self.labels.append(0)

            # Wrong ending
            self.data.append([story, wrong_ending])
            self.labels.append(1)

    @staticmethod
    def _mismatched_shuffle(endings, rng):
        """Return a shuffled copy of ``endings`` avoiding ``out[i] == endings[i]``.

        After shuffling, each colliding position is swapped with another
        position such that neither of the two collides afterwards.  If no such
        partner exists (e.g. a single row, or all endings identical) the
        collision is left in place.
        """
        shuffled = list(endings)
        rng.shuffle(shuffled)
        n = len(shuffled)
        for i in range(n):
            if shuffled[i] != endings[i]:
                continue
            for offset in range(1, n):
                j = (i + offset) % n
                if shuffled[j] != endings[i] and shuffled[i] != endings[j]:
                    shuffled[i], shuffled[j] = shuffled[j], shuffled[i]
                    break
        return shuffled

    def __getitem__(self, idx):
        """Return ``([story, ending], label)`` for sample ``idx``."""
        return self.data[idx], self.labels[idx]

    def __len__(self):
        """Number of samples (two per CSV row)."""
        assert len(self.data) == len(self.labels)
        return len(self.labels)

In [None]:
class ClozeTest(torch.utils.data.Dataset):
    """Story Cloze Test rows turned into next-sentence-prediction pairs.

    Each CSV row is read as: an id column, the story sentences, two candidate
    endings and, in the last column, the number ("1" or "2") of the right
    ending.  Every row yields two samples sharing the same story context:

    * ``[story, ending 1]`` -- label ``0`` if ending 1 is right, else ``1``
    * ``[story, ending 2]`` -- label ``0`` if ending 2 is right, else ``1``

    Rows whose last column is neither "1" nor "2" (for example a header line)
    and rows too short to hold two endings plus an answer are skipped.
    Otherwise such a row would contribute two samples both labelled "wrong".

    Args:
        dev: If True (default) read ``cloze_test.csv`` (used for evaluation),
            otherwise read ``cloze_train.csv``.
    """

    def __init__(self, dev=True):
        filename = 'cloze_test.csv' if dev else 'cloze_train.csv'

        self.data = []
        self.labels = []

        # newline='' is what the csv module documents for files handed to it.
        with open(currentDirectory + filename,
                  'r', encoding='utf-8', newline='') as d:
            reader = csv.reader(d, quotechar='"', delimiter=',',
                                quoting=csv.QUOTE_ALL, skipinitialspace=True)
            for sample in reader:
                # Need at least: id, ending 1, ending 2, answer.
                if len(sample) < 4:
                    continue

                right_ending = sample[-1].strip()
                if right_ending not in ("1", "2"):
                    # Header line or malformed answer: not a usable sample.
                    continue

                start = " ".join(sample[1:-3])
                end1 = sample[-3]
                end2 = sample[-2]

                self.data.append([start, end1])
                self.labels.append(0 if right_ending == "1" else 1)

                self.data.append([start, end2])
                self.labels.append(0 if right_ending == "2" else 1)

    def __getitem__(self, idx):
        """Return ``([story, ending], label)`` for sample ``idx``."""
        return self.data[idx], self.labels[idx]

    def __len__(self):
        """Number of samples (two per usable CSV row)."""
        assert len(self.data) == len(self.labels)
        return len(self.labels)


In [None]:
def train(model_file=BASE_MODEL, batch_size=BATCH_SIZE,
          warmup_epochs=WARMUP_EPOCHS, train_epochs=TRAIN_EPOCHS,
          last_epoch=LAST_EPOCH, cloze_test=True):
    """Fine-tune BertForNextSentencePrediction and save a checkpoint per epoch.

    Args:
        model_file: Name or path of the checkpoint to start from.  The
            tokenizer is always loaded from ``BASE_MODEL``.
        batch_size: Number of (story, ending) pairs per optimisation step.
        warmup_epochs: Epochs over which the learning rate rises linearly.
        train_epochs: Total number of epochs in the schedule; the learning
            rate decays linearly to zero at the end of the last one.
        last_epoch: Index of the last epoch that was already completed, or
            ``-1`` to start from scratch.  Training continues with epoch
            ``last_epoch + 1`` and the schedule is advanced accordingly.
            Only the learning-rate schedule is resumed; optimizer state is
            created fresh.
        cloze_test: If True train on ``ClozeTest(dev=False)``, otherwise on
            ``RocStories()``.

    Raises:
        ValueError: If ``last_epoch`` is smaller than -1.

    Side effects:
        After every epoch the model is written to ``currentDirectory`` as
        ``bertfornsp_finetuned<epoch>`` (cloze) or
        ``bertfornsp_roc_finetuned<epoch>`` (ROCStories).
    """
    if last_epoch < -1:
        raise ValueError("last_epoch must be >= -1, got %r" % (last_epoch,))

    tokenizer = BertTokenizer.from_pretrained(BASE_MODEL)
    model = BertForNextSentencePrediction.from_pretrained(model_file)

    # Send to GPU and enable training mode
    model = model.to(device)
    model.train()

    if cloze_test:
        train_set = ClozeTest(dev=False)
        checkpoint_prefix = "bertfornsp_finetuned"
    else:
        train_set = RocStories()
        checkpoint_prefix = "bertfornsp_roc_finetuned"

    trainloader = torch.utils.data.DataLoader(train_set,
                                              batch_size=batch_size,
                                              shuffle=True)

    # LR maybe needs to be optimized
    optimizer = AdamW(model.parameters(), lr=1e-5)
    n_batches = len(trainloader)

    # The scheduler is stepped once per batch, so it counts batches, not
    # epochs.  After `last_epoch + 1` completed epochs, the index of the last
    # step taken is `(last_epoch + 1) * n_batches - 1`; for a fresh run
    # (last_epoch == -1) this is exactly the -1 the scheduler expects.
    # Passing `last_epoch * n_batches` instead yields -n_batches for a fresh
    # run, which the scheduler treats as a resume and rejects.
    last_step = (last_epoch + 1) * n_batches - 1

    # When resuming (last_step != -1) the scheduler requires `initial_lr` in
    # every param group; a freshly created optimizer does not have it yet.
    for group in optimizer.param_groups:
        group.setdefault('initial_lr', group['lr'])

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=(warmup_epochs * n_batches),
        num_training_steps=(train_epochs * n_batches),
        last_epoch=last_step
    )

    for epoch in tqdm(range(last_epoch + 1, train_epochs)):

        for stories, labels in trainloader:
            # PyTorch accumulates gradients, so clear them for every batch
            optimizer.zero_grad()

            # The default collate function turns a batch of [story, ending]
            # pairs into [all stories, all endings].
            start = stories[0]
            end = stories[1]

            labels = labels.to(device)

            # Tokenize sentence pairs.  All sequences in a batch must have
            # the same length, so shorter ones are padded with [PAD] tokens
            # and over-long pairs are truncated to the model's maximum input
            # length instead of failing inside the model.
            tokenized_batch = tokenizer(start, text_pair=end,
                                        padding=True, truncation=True,
                                        return_tensors='pt').to(device)

            outputs = model(**tokenized_batch, labels=labels)
            outputs.loss.backward()
            optimizer.step()
            scheduler.step()  # one scheduler step per batch

        model.save_pretrained(currentDirectory +
                              checkpoint_prefix + str(epoch))

In [None]:
def test(model_file=BASE_MODEL, verbose = False):
    """Evaluate a BertForNextSentencePrediction checkpoint on ClozeTest().

    Prints a confusion matrix and a classification report over all
    (story, ending) pairs of the cloze dev set.  Class ``0`` means "the ending
    is the right continuation", class ``1`` means "wrong ending".

    Args:
        model_file: Name or path of the checkpoint to evaluate.  The tokenizer
            is always loaded from ``BASE_MODEL``.
        verbose: If True, print for every pair the decoded input, the
            probability (in percent) that the ending is right, whether the
            model predicts it to be right, and whether it actually is right.
            The progress bar is disabled in that case.
    """
    softmax = torch.nn.Softmax(dim=1)
    tokenizer = BertTokenizer.from_pretrained(BASE_MODEL)
    model = BertForNextSentencePrediction.from_pretrained(model_file)

    # Send to GPU and switch to evaluation mode
    model = model.to(device)
    model.eval()

    # Dataloader
    devloader = torch.utils.data.DataLoader(ClozeTest(), batch_size=10)

    pred_list, label_list = list(), list()

    # Inference only: do not record an autograd graph.
    with torch.no_grad():
        for stories, labels in tqdm(devloader, disable=verbose):

            start = stories[0]
            end = stories[1]

            # Tokenize sentence pairs.  All sequences in a batch must have
            # the same length, so shorter ones are padded with [PAD] tokens
            # and over-long pairs are truncated to the model's maximum input
            # length instead of failing inside the model.
            tokenized_batch = tokenizer(start, text_pair=end,
                                        padding=True, truncation=True,
                                        return_tensors='pt').to(device)

            # Only the logits are needed, so no labels are passed and no
            # loss is computed.
            logits = model(**tokenized_batch).logits

            # Class 0 wins if the "right continuation" logit is the larger one
            predictions = logits.argmax(dim=1).int()

            # Extra info print() if verbose
            if verbose:
                probs = softmax(logits).cpu()
                for i, element_input_ids in enumerate(tokenized_batch.input_ids):
                    print(tokenizer.decode(element_input_ids))
                    print("Probability:", probs[i][0].item() * 100)
                    # Class 0 is the "right ending" class, so compare with 0;
                    # bool(class_id) would print False for a right ending.
                    print("Predicted: ", predictions[i].item() == 0)
                    print("True label: ", labels[i].item() == 0)

            pred_list.extend(predictions.tolist())
            label_list.extend(labels.tolist())

    print(confusion_matrix(label_list, pred_list))
    print(classification_report(label_list, pred_list))

In [None]:
if __name__ == "__main__":
    # Baseline: test the pretrained model without any fine-tuning
    test()

    # train() saves one checkpoint per epoch and appends the epoch index to
    # the directory name, so the model after the final epoch lives in
    # "<prefix><TRAIN_EPOCHS - 1>"; a directory without the suffix is never
    # written.

    #Fine-tune model "bertfornsp_finetuned"
    #train(cloze_test=True)

    #Test model "bertfornsp_finetuned"
    #test(model_file=(currentDirectory + "bertfornsp_finetuned"
    #                 + str(TRAIN_EPOCHS - 1)))

    #Fine-tune model "bertfornsp_roc_finetuned"
    train(cloze_test=False)

    #Test model "bertfornsp_roc_finetuned" (checkpoint of the last epoch)
    test(model_file=(currentDirectory + "bertfornsp_roc_finetuned"
                     + str(TRAIN_EPOCHS - 1)))

Some weights of the model checkpoint at bert-large-uncased-whole-word-masking were not used when initializing BertForNextSentencePrediction: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|██████████| 11/11 [29:05<00:00, 158.67s/it]
100%|██████████| 375/375 [00:52<00:00,  7.09it/s]

[[1578  293]
 [ 353 1520]]
              precision    recall  f1-score   support

           0       0.82      0.84      0.83      1871
           1       0.84      0.81      0.82      1873

    accuracy                           0.83      3744
   macro avg       0.83      0.83      0.83      3744
weighted avg       0.83      0.83      0.83      3744




