[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googlecolab/colabtools/blob/master/notebooks/colab-github-demo.ipynb)

In [None]:
!pip install transformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/ed/d5/f4157a376b8a79489a76ce6cfe147f4f3be1e029b7144fa7b8432e8acb26/transformers-4.4.2-py3-none-any.whl (2.0MB)
[K     |████████████████████████████████| 2.0MB 5.8MB/s 
[?25hCollecting tokenizers<0.11,>=0.10.1
[?25l  Downloading https://files.pythonhosted.org/packages/71/23/2ddc317b2121117bf34dd00f5b0de194158f2a44ee2bf5e47c7166878a97/tokenizers-0.10.1-cp37-cp37m-manylinux2010_x86_64.whl (3.2MB)
[K     |████████████████████████████████| 3.2MB 25.0MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/08/cd/342e584ee544d044fb573ae697404ce22ede086c9e87ce5960772084cad0/sacremoses-0.0.44.tar.gz (862kB)
[K     |████████████████████████████████| 870kB 38.8MB/s 
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
  Created wheel for sacremoses: filename=sacremoses-0.0.44-cp37-none-any.whl size=886084 sha256=26837

# Imports and Global Variables

In [None]:
import csv
import html

import matplotlib
import torch
from IPython.display import display, HTML
from matplotlib import pyplot as plt
from torch.nn.functional import relu
from torch.nn.functional import softmax
from tqdm import tqdm
from transformers import BertForNextSentencePrediction, BertTokenizer

# Tokenizer of the base checkpoint; presumably the fine-tuned models below were
# trained from it — TODO confirm.
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking')

# The three dataset variants compared by the visualizer.
TRIGGERS_ONLY_FILE = '../datasets/cloze_test_triggers_only.csv'
TRIGGERS_REMOVED_ONLY_FILE = '../datasets/cloze_test_triggers_removed_only.csv'
TRIGGERS_SYNONYMIZED_ONLY_FILE = '../datasets/cloze_test_triggers_synonymized_only.csv'

# Fine-tuned BertForNextSentencePrediction checkpoints, one per section below.
MODEL_CLOZE_FILE = '../models/bertfornsp_clozeonly_finetuned10'
MODEL_ROC_FILE = '../models/bertfornsp_roc_finetuned1'
MODEL_MIXED_FILE = '../models/bertfornsp_mixed5'

# Trigger word -> sample index into the datasets above. The indices were
# picked by hand from the triggers-only dataset.
examples = {
    'instead': 0,
    'ever': 17,
    'anymore': 26,
    'too': 8,
    'eventually': 30,
    'immediately': 42,
    'anyway': 5,
    'soon': 25,
    'later': 38,
    'now': 20,
    'finally': 3,
}

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=28.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466062.0, style=ProgressStyle(descripti…




# Datasets

In [None]:
class ClozeTest(torch.utils.data.Dataset):
    """Cloze-test CSV exposed as (context, ending) sentence pairs.

    After a header row, each CSV record is read as
    ``[<first column, skipped>, context field(s)..., ending 1, ending 2, right ending]``.
    The context fields are joined with single spaces. Every record yields two
    samples, one per candidate ending, each returned as ``([context, ending], label)``.
    ``label`` is ``0`` when that ending is the right one (``right ending`` equals
    ``"1"`` or ``"2"`` respectively) and ``1`` otherwise.

    Raises:
        ValueError: if a non-blank record has fewer than 4 fields, which is
            too few to locate both endings and the right-ending column.
    """

    def __init__(self, filename):
        self.data = []
        self.labels = []

        # newline='' is what the csv module expects, so quoted fields that
        # contain line breaks are parsed correctly.
        with open(filename, 'r', encoding='utf-8', newline='') as d:
            reader = csv.reader(d, quotechar='"', delimiter=',',
                                quoting=csv.QUOTE_ALL, skipinitialspace=True)
            # Drop the header row; an empty file simply yields an empty dataset.
            next(reader, None)

            for sample in reader:
                if not sample:
                    continue  # csv.reader yields [] for blank lines
                if len(sample) < 4:
                    raise ValueError(
                        f"{filename}: record ending on line {reader.line_num} has "
                        f"{len(sample)} field(s); expected at least 4"
                    )

                start = " ".join(sample[1:-3])
                end1, end2, right_ending = sample[-3:]

                self.data.append([start, end1])
                self.labels.append(0 if right_ending == "1" else 1)

                self.data.append([start, end2])
                self.labels.append(0 if right_ending == "2" else 1)

    def __getitem__(self, idx):
        """Return ``([context, ending], label)`` for sample ``idx``."""
        return self.data[idx], self.labels[idx]

    def __len__(self):
        """Number of (context, ending) samples, i.e. two per CSV record."""
        assert len(self.data) == len(self.labels)
        return len(self.labels)

In [None]:
# One dataset per variant; visualize_for_model reads the same sample index
# from each of them.
triggers_only_set, triggers_removed_only_set, triggers_synonymized_only_set = (
    ClozeTest(path)
    for path in (
        TRIGGERS_ONLY_FILE,
        TRIGGERS_REMOVED_ONLY_FILE,
        TRIGGERS_SYNONYMIZED_ONLY_FILE,
    )
)

# Saliency Map

In [None]:
def _embedding_saliencies(model, tokens):
    """Return (L1-normalised per-token saliencies, predicted class tensor) for one encoded pair."""
    captured = {}

    def keep_embedding_output(module, inputs, output):
        # Keep the graph-attached embedding output so its gradient can be
        # requested directly (no deprecated module backward hook needed).
        captured["emb"] = output

    handle = model.bert.embeddings.register_forward_hook(keep_embedding_output)
    try:
        model.eval()
        logits = model(**tokens).logits.view(-1, 2)
    finally:
        # Always detach the hook, even if the forward pass fails, so hooks
        # never accumulate on the shared model.
        handle.remove()

    prediction = logits.argmax(dim=-1)
    # Gradient of the loss w.r.t. the model's own prediction, as per Han et al. 2020.
    # autograd.grad computes only this gradient and leaves parameter .grad untouched.
    loss = torch.nn.CrossEntropyLoss()(logits, prediction)
    emb = captured["emb"]
    (emb_grad,) = torch.autograd.grad(loss, emb)

    saliencies = (-emb_grad * emb.detach()).sum(dim=-1)
    norm = torch.linalg.norm(saliencies, ord=1, dim=-1, keepdim=True)
    # Guard against an all-zero gradient, which would otherwise produce NaNs.
    saliencies = saliencies / norm.clamp_min(torch.finfo(norm.dtype).tiny)
    return saliencies[0], prediction  # squeeze the batch of one


def _saliency_html(token_names, saliencies):
    """Return HTML spans colouring each token by its saliency (RdBu: red negative, blue positive)."""
    # Visualization. Courtesy of https://gist.github.com/ihsgnef
    peak = saliencies.abs().max()
    if peak > 0:
        colors = saliencies / peak * 0.5 + 0.5
    else:
        colors = torch.full_like(saliencies, 0.5)  # nothing salient: neutral colour
    cmap = plt.get_cmap('RdBu')
    template = '<span class="barcode" style="color: black; background-color: {}">{}</span>'
    return ''.join(
        template.format(
            matplotlib.colors.rgb2hex(cmap(color)[:3]),
            '&nbsp;' + html.escape(token_name) + '&nbsp;',
        )
        for token_name, color in zip(token_names, colors.tolist())
    )


def saliency_map(model, tokenizer, input, ending, label, first_row):
    """Display a gradient-x-input saliency map for one (context, ending) pair.

    The pair is encoded as a BERT sentence pair. The cross-entropy loss of the
    model's own prediction is back-propagated to the output of
    ``model.bert.embeddings``. Each token's saliency is the dot product of the
    negative gradient and that embedding output, L1-normalised over the
    sequence. One HTML row is displayed:
    ``first_row``, the gold label, the prediction, then the colour-coded tokens.

    Args:
        model: a ``BertForNextSentencePrediction``-like model (must expose
            ``model.bert.embeddings`` and return ``.logits``).
        tokenizer: the matching BERT tokenizer.
        input: first segment (the story context).
        ending: second segment (the candidate ending).
        label: gold label of the pair; only shown in the displayed row.
        first_row: caption shown at the start of the row.
    """
    tokens = tokenizer(input, text_pair=ending, return_tensors='pt')
    # Names taken straight from the ids stay aligned one-to-one with the saliencies.
    token_names = tokenizer.convert_ids_to_tokens(tokens.input_ids[0].tolist())

    saliencies, prediction = _embedding_saliencies(model, tokens)

    header = (html.escape(first_row) + "\tlabel:" + str(label)
              + "\tprediction:" + str(prediction[0].item()) + "\t")
    display(HTML(header + _saliency_html(token_names, saliencies)))

# Differences Visualizer

In [None]:
def visualize_for_model(model):
    """Show saliency rows for every hand-picked example under each dataset variant.

    For each ``trigger -> sample index`` entry in ``examples``, the same index
    is read from the triggers-only, triggers-removed and triggers-synonymized
    datasets, and one saliency row is displayed per variant.
    """
    variants = (
        (triggers_only_set, "with triggers"),
        (triggers_removed_only_set, "triggers removed"),
        (triggers_synonymized_only_set, "triggers synonymized"),
    )
    for trigger, sentence_id in examples.items():
        print("sentences with " + trigger + ":")
        for dataset, caption in variants:
            (context, candidate), gold = dataset[sentence_id]
            saliency_map(model, tokenizer, context, candidate, gold, caption)
        print()

# Cloze Only

In [None]:
# Checkpoint for the "Cloze Only" setting.
model = BertForNextSentencePrediction.from_pretrained(
    MODEL_CLOZE_FILE
)
visualize_for_model(model)

sentences with instead:





sentences with ever:



sentences with anymore:



sentences with too:



sentences with eventually:



sentences with immediately:



sentences with anyway:



sentences with soon:



sentences with later:



sentences with now:



sentences with finally:





# ROC Stories Only

In [None]:
# Checkpoint for the "ROC Stories Only" setting.
model = BertForNextSentencePrediction.from_pretrained(
    MODEL_ROC_FILE
)
visualize_for_model(model)

sentences with instead:





sentences with ever:



sentences with anymore:



sentences with too:



sentences with eventually:



sentences with immediately:



sentences with anyway:



sentences with soon:



sentences with later:



sentences with now:



sentences with finally:





# Cloze with 5,000 ROC Stories Mixed In

In [None]:
# Checkpoint for the mixed Cloze + ROC Stories setting.
model = BertForNextSentencePrediction.from_pretrained(
    MODEL_MIXED_FILE
)
visualize_for_model(model)

sentences with instead:





sentences with ever:



sentences with anymore:



sentences with too:



sentences with eventually:



sentences with immediately:



sentences with anyway:



sentences with soon:



sentences with later:



sentences with now:



sentences with finally:



