# Imports and Global Variables

In [23]:
import csv
import html
import os
import random
import sys
from pathlib import Path

import matplotlib
import torch
from IPython.display import display, HTML
from torch.nn.functional import relu
from torch.nn.functional import softmax
from transformers import BertForNextSentencePrediction, BertTokenizer

# Project root: the parent of the directory that contains the entry script (sys.argv[0]).
# NOTE(review): inside a Jupyter kernel sys.argv[0] usually names the kernel launcher,
# not this notebook -- confirm the resolved root is the intended project folder.
project_path = Path(os.path.realpath(sys.argv[0])).parent.parent
# Both locations are kept as plain strings with a trailing "/" so file names can be appended.
modelpath = f"{project_path / 'models'}/"
datapath = f"{project_path / 'datasets'}/"

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Name of the underlying pretrained language model checkpoint.
BASE_MODEL = 'bert-large-uncased-whole-word-masking'


In [24]:
def getModelFileName(model_name, last_epoch):
    """Build the location of a saved model inside the models folder.

    :param model_name: base name of the saved model (a string)
    :param last_epoch: epoch identifier appended to the name; converted with ``str``
    :return: ``modelpath`` + ``model_name`` + ``str(last_epoch)`` as one string
    """
    checkpoint = modelpath + model_name
    checkpoint += str(last_epoch)
    return checkpoint

# Datasets

In [19]:
class ClozeTest(torch.utils.data.Dataset):
    """Cloze-test stories flattened into (context, ending) pairs.

    Every CSV row contributes two consecutive samples, one per candidate ending:
    ``[story, ending1]`` followed by ``[story, ending2]``. A sample is labelled ``0``
    when its ending is the one the row marks as correct and ``1`` otherwise.

    Expected row layout (taken from how the columns are sliced below): the first
    column is skipped, the last three columns are ``ending1, ending2, right_ending``
    and everything in between is joined with spaces to form the story.
    """

    def __init__(self, dev=True, hypothesis_only=False, file = None):
        """
        :param hypothesis_only: Replaces story with empty string. Only Keeps endings as they are.
        :param dev: if dev=True, load dev set for testing, otherwise training set
        :param file: csv file to load the data from
        """
        # An explicit file name wins; otherwise `dev` selects test vs. training set.
        if file is None:
            csv_name = 'cloze_test.csv' if dev else 'cloze_train.csv'
        else:
            csv_name = file
        csv_path = datapath + csv_name

        self.data = []
        self.labels = []

        # newline='' is what the csv module requires so that line breaks inside
        # quoted fields are handed to the parser untouched.
        with open(csv_path, 'r', encoding='utf-8', newline='') as handle:
            reader = csv.reader(handle, quotechar='"', delimiter=',',
                                quoting=csv.QUOTE_ALL, skipinitialspace=True)
            # Drop the header row; the default keeps an empty file from raising.
            next(reader, None)

            for sample in reader:
                # csv.reader yields [] for blank lines (e.g. a trailing newline).
                if not sample:
                    continue

                start = "" if hypothesis_only else " ".join(sample[1:-3])
                end1, end2, right_ending = sample[-3], sample[-2], sample[-1]

                self.data.append([start, end1])
                self.labels.append(0 if "1" == right_ending else 1)

                self.data.append([start, end2])
                self.labels.append(0 if "2" == right_ending else 1)

    def __getitem__(self, idx):
        """Return ``([story, ending], label)`` for the sample at position ``idx``."""
        X = self.data[idx]
        y = self.labels[idx]
        return X, y

    def __len__(self):
        """Number of (story, ending) samples, i.e. two per CSV data row."""
        assert len(self.data) == len(self.labels)
        return len(self.labels)

# Saliency Maps

In [26]:
def saliency_map(model, tokenizer, input, ending, label):
    """Print the model's prediction and display a per-token saliency map.

    The saliency of a token is ``sum(-gradient * activation)`` over the hidden
    dimension of the embedding layer's output, where the gradient is that of the
    cross-entropy loss taken against the model's own predicted class (loss gradient
    as per Han et al. 2020). The values are L1-normalised over the sequence and then
    rendered as coloured HTML spans using the 'RdBu' colormap.

    :param model: model exposing ``model.bert.embeddings`` whose output has ``.logits``
    :param tokenizer: tokenizer called as ``tokenizer(input, text_pair=ending,
        return_tensors='pt')``; must also provide ``convert_ids_to_tokens``
    :param input: first text segment (the story)
    :param ending: second text segment (the candidate ending)
    :param label: forwarded to the model as ``labels``; it does not enter the
        saliency computation, which uses the predicted class as target
    :return: None; the result is printed and displayed
    """
    # One-entry dicts let the hook closures hand their tensors back to this scope.
    acts = dict()
    grads = dict()

    def save_activation(module, hook_input, hook_output):
        acts["emb"] = hook_output.detach()

    def save_gradient(module, grad_input, grad_output):
        grads["emb"] = grad_output[0].detach()  # 'grad_output' is a tuple

    embeddings = model.bert.embeddings
    frw_handle = embeddings.register_forward_hook(save_activation)
    # NOTE(review): recent PyTorch releases deprecate register_backward_hook in favour
    # of register_full_backward_hook; kept unchanged here -- confirm behaviour for a
    # module whose inputs are integer ids before switching.
    bck_handle = embeddings.register_backward_hook(save_gradient)

    try:
        tokens = tokenizer(input, text_pair=ending, return_tensors='pt').to(device)
        # One label per input id, so labels and saliency values line up exactly.
        # (Decoding to text and splitting on whitespace merges wordpieces and
        # therefore yields a different number of items than there are saliencies.)
        token_names = tokenizer.convert_ids_to_tokens(tokens.input_ids[0].tolist())

        model.eval()
        model.zero_grad()
        model = model.to(device)

        logits = model(**tokens, labels=torch.tensor([label]).to(device)).logits.view(-1, 2)
        predicted = logits.argmax(dim=-1)
        # Gradient of loss as per Han et al. 2020 calculated
        torch.nn.CrossEntropyLoss()(logits, predicted).backward()
        prediction = predicted.int().item()
    finally:
        # Always detach the hooks, even on failure: the same model object is
        # reused across calls and must not accumulate stale hooks.
        frw_handle.remove()
        bck_handle.remove()

    saliencies = (-grads["emb"] * acts["emb"]).sum(dim=-1)
    # Guard the divisions: an exactly-zero gradient would otherwise give 0/0 = NaN.
    tiny = torch.finfo(saliencies.dtype).tiny
    norm = torch.linalg.norm(saliencies, ord=1, dim=-1, keepdim=True)
    saliencies = saliencies / norm.clamp_min(tiny)  # normalizing the saliencies
    saliencies = saliencies[0]  # squeezing the batch of one

    print("Predicted Label: ", prediction, "\t")

    # Visualization. Courtesy of https://gist.github.com/ihsgnef
    # Map saliencies to [0, 1] with 0.5 meaning "no saliency".
    peak = saliencies.abs().max().clamp_min(tiny)
    colors = saliencies / peak * 0.5 + 0.5
    # matplotlib.cm.get_cmap is gone from recent matplotlib; prefer the registry
    # when it exists and fall back to the old accessor otherwise.
    if hasattr(matplotlib, 'colormaps'):
        cmap = matplotlib.colormaps['RdBu']
    else:
        cmap = matplotlib.cm.get_cmap('RdBu')
    template = '<span class="barcode" style="color: black; background-color: {}">{}</span>'
    spans = []
    for token_name, color in zip(token_names, colors.tolist()):
        hex_color = matplotlib.colors.rgb2hex(cmap(color)[:3])
        # Escape the token so characters such as '<' or '&' cannot break the markup.
        spans.append(template.format(hex_color, '&nbsp;' + html.escape(token_name) + '&nbsp;'))

    display(HTML(''.join(spans)))

def print_salience_sample(testfiles=["cloze_test.csv"], size = 10, seed=None):
    """
    Draw random samples and show their saliency maps for every model.

    Loads the base BERT model plus three fine-tuned checkpoints, then for each
    drawn index prints, per model and per dataset, the true label followed by the
    saliency map of that sample.

    :param testfiles: A List of csv filenames of datasets to be compared side by side.
    :param size: Nr of samples to be drawn from the datasets.
    :param seed: Optional seed for reproducible sampling. With the default ``None``
        the module-level ``random`` generator is used.
    """
    if not testfiles:
        raise ValueError("testfiles must name at least one dataset")

    # A private generator keeps seeded runs from touching the global random state.
    rng = random if seed is None else random.Random(seed)

    models = [BertForNextSentencePrediction.from_pretrained(BASE_MODEL),
        BertForNextSentencePrediction.from_pretrained(getModelFileName("bertfornsp_roc_finetuned", "1")),
        BertForNextSentencePrediction.from_pretrained(getModelFileName("bertfornsp_clozeonly_finetuned", "10")),
        BertForNextSentencePrediction.from_pretrained(getModelFileName("bertfornsp_mixed", "5"))]
    modelnames = ["BERT: ", "ROC Model: ", "CLOZE Model: ", "MIXED Model: "]
    tokenizer = BertTokenizer.from_pretrained(BASE_MODEL)

    cloze_tests = [ClozeTest(file=file, dev=True) for file in testfiles]
    # The same index is used for every dataset, so it must be valid in all of them,
    # not only in the last one.
    shared_len = min(len(cloze_test) for cloze_test in cloze_tests)

    for _ in range(size):
        i = rng.randrange(shared_len)  # Example data point
        dps = [cloze_test[i] for cloze_test in cloze_tests]
        for modelname, model in zip(modelnames, models):
            print(modelname)
            for (story_text, ending), label in dps:
                print("True Label: ", label)
                saliency_map(model=model, tokenizer=tokenizer, input=story_text,
                             ending=ending, label=label)
        print("\n")

In [27]:
# Five random samples, each shown for the original and the noised test set
print_salience_sample(["cloze_test.csv", "noise_test_set"], size=5)

# Twenty random samples from the default test set (appendix)
print_salience_sample(size=20)

Error: Pip module Unable to parse debugpy output, please log an issue with https://github.com/microsoft/vscode-jupyter is required for debugging cells. You will need to install it to debug cells.