[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googlecolab/colabtools/blob/master/notebooks/colab-github-demo.ipynb)

In [None]:
!pip install transformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/ed/d5/f4157a376b8a79489a76ce6cfe147f4f3be1e029b7144fa7b8432e8acb26/transformers-4.4.2-py3-none-any.whl (2.0MB)
[K     |████████████████████████████████| 2.0MB 11.2MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
[K     |████████████████████████████████| 890kB 40.3MB/s 
Collecting tokenizers<0.11,>=0.10.1
[?25l  Downloading https://files.pythonhosted.org/packages/71/23/2ddc317b2121117bf34dd00f5b0de194158f2a44ee2bf5e47c7166878a97/tokenizers-0.10.1-cp37-cp37m-manylinux2010_x86_64.whl (3.2MB)
[K     |████████████████████████████████| 3.2MB 39.6MB/s 
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
  Created wheel for sacremoses: filename=sacremoses-0.0.43-cp37-none-any.whl size=893262 sha256=1a6039ca95

# Imports and Global Variables

In [None]:
import csv
import torch
from torch.nn.functional import softmax
from transformers import BertForNextSentencePrediction, BertTokenizer
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report

BASE_MODEL = 'bert-large-uncased-whole-word-masking'

device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Shared tokenizer: the fine-tuned checkpoints below are also evaluated with it.
tokenizer = BertTokenizer.from_pretrained(BASE_MODEL)

# String prefix prepended to the dataset paths below when the ClozeTest
# datasets are built. '' resolves them against the kernel's working directory;
# set it to e.g. a mounted Google Drive folder (with a trailing slash) on Colab.
current_directory = ''

TRIGGERS_ONLY_FILE = '../datasets/cloze_test_triggers_only.csv'
TRIGGERS_REMOVED_ONLY_FILE = '../datasets/cloze_test_triggers_removed_only.csv'
TRIGGERS_SYNONYMIZED_ONLY_FILE = '../datasets/cloze_test_triggers_synonymized_only.csv'

MODEL_CLOZE_FILE = '../models/bertfornsp_clozeonly_finetuned10'
MODEL_ROC_FILE = '../models/bertfornsp_roc_finetuned1'
MODEL_MIXED_FILE = '../models/bertfornsp_mixed5'

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=28.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466062.0, style=ProgressStyle(descripti…




# Datasets

In [None]:
class ClozeTest(torch.utils.data.Dataset):
    """Story Cloze Test rows expanded into binary (context, ending) pairs.

    The CSV's first row is treated as a header and skipped; blank rows are
    ignored. Each remaining row is read as
    ``id, sentence_1, ..., sentence_k, ending_1, ending_2, right_ending``
    where ``right_ending`` is the string "1" or "2". Every row yields two
    samples, ``[context, ending_1]`` and ``[context, ending_2]``, labelled 0
    when that ending is the right one and 1 otherwise (the
    BertForNextSentencePrediction convention: 0 = sentence B continues A).

    Args:
        filename: CSV path, appended to ``root``.
        root: string prefix prepended to ``filename``. When None (default),
            the notebook-level ``current_directory`` is used if it is defined,
            otherwise "" (``filename`` is used as-is).
    """

    def __init__(self, filename, root=None):
        if root is None:
            # Fall back gracefully instead of raising NameError when the
            # notebook never defined `current_directory`.
            root = globals().get("current_directory", "")

        # newline='' lets the csv module handle quoted fields with line breaks.
        with open(root + filename, 'r', encoding='utf-8', newline='') as d:
            reader = csv.reader(d, quotechar='"', delimiter=',',
                                quoting=csv.QUOTE_ALL, skipinitialspace=True)
            next(reader, None)  # drop the header row
            rows = [row for row in reader if row]  # skip blank lines

        self.data = []
        self.labels = []

        for sample in rows:
            start = " ".join(sample[1:-3])  # all context sentences, id excluded
            end1, end2, right_ending = sample[-3], sample[-2], sample[-1]

            self.data.append([start, end1])
            self.labels.append(0 if right_ending == "1" else 1)

            self.data.append([start, end2])
            self.labels.append(0 if right_ending == "2" else 1)

    def __getitem__(self, idx):
        """Return ``([context, ending], label)`` for sample ``idx``."""
        return self.data[idx], self.labels[idx]

    def __len__(self):
        assert len(self.data) == len(self.labels)
        return len(self.labels)

In [None]:
class ClozeTest_MC(torch.utils.data.Dataset):
    """Story Cloze Test rows as two-way multiple-choice samples.

    The CSV's first row is treated as a header and skipped; blank rows are
    ignored. Each remaining row is read as
    ``id, sentence_1, ..., sentence_k, ending_1, ending_2, right_ending``
    where ``right_ending`` is the string "1" or "2". Every row yields one
    sample ``[context, ending_1, ending_2]`` labelled 0 if ending 1 is right
    and 1 otherwise.

    Args:
        filename: CSV path, appended to ``root``.
        root: string prefix prepended to ``filename``. When None (default),
            the notebook-level ``current_directory`` is used if it is defined,
            otherwise "" (``filename`` is used as-is).
    """

    def __init__(self, filename, root=None):
        if root is None:
            # Fall back gracefully instead of raising NameError when the
            # notebook never defined `current_directory`.
            root = globals().get("current_directory", "")

        # newline='' lets the csv module handle quoted fields with line breaks.
        with open(root + filename, 'r', encoding='utf-8', newline='') as d:
            reader = csv.reader(d, quotechar='"', delimiter=',',
                                quoting=csv.QUOTE_ALL, skipinitialspace=True)
            next(reader, None)  # drop the header row
            rows = [row for row in reader if row]  # skip blank lines

        self.data = []
        self.labels = []

        for sample in rows:
            start = " ".join(sample[1:-3])  # all context sentences, id excluded
            end1, end2, right_ending = sample[-3], sample[-2], sample[-1]

            self.data.append([start, end1, end2])
            self.labels.append(0 if right_ending == "1" else 1)

    def __getitem__(self, idx):
        """Return ``([context, ending_1, ending_2], label)`` for sample ``idx``."""
        return self.data[idx], self.labels[idx]

    def __len__(self):
        assert len(self.data) == len(self.labels)
        return len(self.labels)

In [None]:
# Pairwise (binary NSP) views: two samples per story, one per candidate ending.
triggers_only_set, triggers_removed_only_set, triggers_synonymized_only_set = [
    ClozeTest(csv_path)
    for csv_path in (TRIGGERS_ONLY_FILE,
                     TRIGGERS_REMOVED_ONLY_FILE,
                     TRIGGERS_SYNONYMIZED_ONLY_FILE)
]

# Multiple-choice views: one sample per story holding both candidate endings.
triggers_only_set_mc, triggers_removed_only_set_mc, triggers_synonymized_only_set_mc = [
    ClozeTest_MC(csv_path)
    for csv_path in (TRIGGERS_ONLY_FILE,
                     TRIGGERS_REMOVED_ONLY_FILE,
                     TRIGGERS_SYNONYMIZED_ONLY_FILE)
]

# Functions for Testing

In [None]:
def test(model, dataset, verbose=False):
    """Evaluate a next-sentence-prediction model on a binary ClozeTest set.

    Each item of `dataset` is ``([context, ending], label)`` where label 0
    means the ending is the right one. The prediction is the argmax over the
    model's two NSP logits, so class 0 means "ending follows the context".
    A sklearn classification report is printed at the end.

    Args:
        model: a BertForNextSentencePrediction (or compatible) model.
        dataset: torch Dataset yielding ``([start, end], label)`` items.
        verbose: if True, print each decoded input, its class-0 probability
            (in percent), the prediction and the gold label.
    """
    # Send to GPU and switch off dropout for evaluation.
    model = model.to(device)
    model.eval()

    devloader = torch.utils.data.DataLoader(dataset, batch_size=10)
    pred_list, label_list = [], []

    # Inference only: skip building autograd graphs (large memory saving on bert-large).
    with torch.no_grad():
        for stories, labels in devloader:
            start, end = stories[0], stories[1]

            # All sequences in a batch must share a length, so shorter pairs
            # are padded with [PAD] tokens.
            tokenized_batch = tokenizer(start, text_pair=end, padding=True,
                                        return_tensors='pt').to(device)

            logits = model(**tokenized_batch).logits
            predictions = logits.argmax(dim=1).int()

            if verbose:
                probs = softmax(logits, dim=1).cpu()
                for i, element_input_ids in enumerate(tokenized_batch.input_ids):
                    print(tokenizer.decode(element_input_ids))
                    print("Probability:", probs[i][0].item() * 100)
                    # True here means class 1, i.e. the ending is judged wrong.
                    print("Predicted: ", bool(predictions[i]))
                    print("True label: ", bool(labels[i]))

            pred_list.extend(predictions.tolist())
            label_list.extend(labels.tolist())

    print(classification_report(label_list, pred_list))

In [None]:
def test_mc(model, dataset, verbose=False):
    """Evaluate an NSP model as two-way multiple choice on a ClozeTest_MC set.

    Each item of `dataset` is ``([context, ending_1, ending_2], label)`` where
    label 0 means ending 1 is right and 1 means ending 2 is right. Both
    (context, ending) pairs are scored, and the choice logits are

        choice 0 = logit0(ending 1) + logit1(ending 2)
        choice 1 = logit1(ending 1) + logit0(ending 2)

    where logit0 is the "follows the context" class. The argmax is the
    predicted ending index. A sklearn classification report is printed.

    Args:
        model: a BertForNextSentencePrediction (or compatible) model.
        dataset: torch Dataset yielding ``([start, end1, end2], label)`` items.
        verbose: if True, print both decoded inputs, the softmax of the
            choice logits for ending 1 (in percent), and predicted/true ending.
    """
    # Send to GPU and switch off dropout for evaluation.
    model = model.to(device)
    model.eval()

    devloader = torch.utils.data.DataLoader(dataset, batch_size=10)
    pred_list, label_list = [], []

    # Inference only: skip building autograd graphs (large memory saving on bert-large).
    with torch.no_grad():
        for stories, labels in devloader:
            start, end1, end2 = stories[0], stories[1], stories[2]

            tokenized_batch_end1 = tokenizer(start, text_pair=end1, padding=True,
                                             return_tensors='pt').to(device)
            tokenized_batch_end2 = tokenizer(start, text_pair=end2, padding=True,
                                             return_tensors='pt').to(device)

            logits_end1 = model(**tokenized_batch_end1).logits
            logits_end2 = model(**tokenized_batch_end2).logits

            # Flip ending-2 logits so column j collects evidence for ending j+1.
            logits = logits_end1 + logits_end2.flip(-1)
            predictions = logits.argmax(dim=1).int()

            if verbose:
                probs = softmax(logits, dim=1).cpu()
                for i in range(len(labels)):
                    print(tokenizer.decode(tokenized_batch_end1.input_ids[i]))
                    print(tokenizer.decode(tokenized_batch_end2.input_ids[i]))
                    print("Ending 1 score:", probs[i][0].item() * 100)
                    print("Predicted ending: ", predictions[i].item() + 1)
                    print("True ending: ", labels[i].item() + 1)

            pred_list.extend(predictions.tolist())
            label_list.extend(labels.tolist())

    print(classification_report(label_list, pred_list))

In [None]:
def test_model(model):
    """Run the binary (pairwise NSP) evaluation on all three trigger variants."""
    print("Binary:")
    for heading, variant_set in (
        ("With triggers:", triggers_only_set),
        ("Triggers removed:", triggers_removed_only_set),
        ("Triggers synonymized:", triggers_synonymized_only_set),
    ):
        print(heading)
        test(model, variant_set)
    print()

In [None]:
def test_model_mc(model):
    """Run the multiple-choice evaluation on all three trigger variants."""
    print("Choice:")
    for heading, variant_set in (
        ("With triggers:", triggers_only_set_mc),
        ("Triggers removed:", triggers_removed_only_set_mc),
        ("Triggers synonymized:", triggers_synonymized_only_set_mc),
    ):
        print(heading)
        test_mc(model, variant_set)
    print()

# BERT

In [None]:
# Baseline: pretrained BERT NSP head, no fine-tuning on story endings.
model = BertForNextSentencePrediction.from_pretrained(BASE_MODEL)
for run_eval in (test_model, test_model_mc):
    run_eval(model)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=434.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1345000548.0, style=ProgressStyle(descr…




Some weights of the model checkpoint at bert-large-uncased-whole-word-masking were not used when initializing BertForNextSentencePrediction: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Binary:
With triggers:
              precision    recall  f1-score   support

           0       0.53      0.16      0.24       174
           1       0.51      0.86      0.64       174

    accuracy                           0.51       348
   macro avg       0.52      0.51      0.44       348
weighted avg       0.52      0.51      0.44       348

Triggers removed:
              precision    recall  f1-score   support

           0       0.61      0.17      0.27       174
           1       0.52      0.89      0.66       174

    accuracy                           0.53       348
   macro avg       0.57      0.53      0.46       348
weighted avg       0.57      0.53      0.46       348

Triggers synonymized:
              precision    recall  f1-score   support

           0       0.57      0.16      0.24       174
           1       0.51      0.89      0.65       174

    accuracy                           0.52       348
   macro avg       0.54      0.52      0.45       348
weighted av

# Cloze Only

In [None]:
# Checkpoint fine-tuned on Story Cloze data only.
model = BertForNextSentencePrediction.from_pretrained(MODEL_CLOZE_FILE)
for run_eval in (test_model, test_model_mc):
    run_eval(model)

Binary:
With triggers:
              precision    recall  f1-score   support

           0       0.80      0.89      0.84       174
           1       0.88      0.78      0.82       174

    accuracy                           0.83       348
   macro avg       0.84      0.83      0.83       348
weighted avg       0.84      0.83      0.83       348

Triggers removed:
              precision    recall  f1-score   support

           0       0.79      0.82      0.81       174
           1       0.82      0.79      0.80       174

    accuracy                           0.80       348
   macro avg       0.80      0.80      0.80       348
weighted avg       0.80      0.80      0.80       348

Triggers synonymized:
              precision    recall  f1-score   support

           0       0.80      0.87      0.83       174
           1       0.86      0.79      0.82       174

    accuracy                           0.83       348
   macro avg       0.83      0.83      0.83       348
weighted av

# ROCStories Only

In [None]:
# Checkpoint fine-tuned on ROCStories only.
model = BertForNextSentencePrediction.from_pretrained(MODEL_ROC_FILE)
for run_eval in (test_model, test_model_mc):
    run_eval(model)

Binary:
With triggers:
              precision    recall  f1-score   support

           0       0.53      0.99      0.69       174
           1       0.95      0.12      0.21       174

    accuracy                           0.56       348
   macro avg       0.74      0.56      0.45       348
weighted avg       0.74      0.56      0.45       348

Triggers removed:
              precision    recall  f1-score   support

           0       0.53      0.99      0.69       174
           1       0.91      0.11      0.20       174

    accuracy                           0.55       348
   macro avg       0.72      0.55      0.45       348
weighted avg       0.72      0.55      0.45       348

Triggers synonymized:
              precision    recall  f1-score   support

           0       0.53      0.99      0.69       174
           1       0.95      0.12      0.21       174

    accuracy                           0.56       348
   macro avg       0.74      0.56      0.45       348
weighted av

# Cloze with 5,000 ROCStories Mixed In

In [None]:
# Checkpoint fine-tuned on Story Cloze data mixed with ROCStories.
model = BertForNextSentencePrediction.from_pretrained(MODEL_MIXED_FILE)
for run_eval in (test_model, test_model_mc):
    run_eval(model)

Binary:
With triggers:
              precision    recall  f1-score   support

           0       0.78      0.93      0.85       174
           1       0.91      0.74      0.81       174

    accuracy                           0.83       348
   macro avg       0.84      0.83      0.83       348
weighted avg       0.84      0.83      0.83       348

Triggers removed:
              precision    recall  f1-score   support

           0       0.80      0.90      0.85       174
           1       0.88      0.78      0.83       174

    accuracy                           0.84       348
   macro avg       0.84      0.84      0.84       348
weighted avg       0.84      0.84      0.84       348

Triggers synonymized:
              precision    recall  f1-score   support

           0       0.78      0.93      0.85       174
           1       0.91      0.74      0.81       174

    accuracy                           0.83       348
   macro avg       0.84      0.83      0.83       348
weighted av