In [None]:
# Enable IPython's autoreload extension in mode 2 so edits to imported modules
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys

# Put the parent directory on the import path — presumably where the local
# `xbert` package lives (TODO confirm against the repository layout).
sys.path.append("../")

In [None]:
from IPython.core.display import HTML

In [None]:
import os

import numpy as np
import torch
import torch.nn.functional as F

from xbert.engine import Engine, weight_of_evidence, difference_of_log_probabilities, calculate_correlation
from xbert import InputInstance, Config
from xbert.visualization import visualize_relevances

In [None]:
from segtok.tokenizer import web_tokenizer

from transformers import RobertaTokenizer, RobertaForSequenceClassification  #, glue_convert_examples_to_features

In [None]:
# Index of the GPU to use, or -1 if no GPU is available.
# NOTE(review): confirm that every consumer of this value accepts -1.
CUDA_DEVICE = 0

# Pretrained checkpoint name handed to `from_pretrained` for both the tokenizer and the model.
MODEL_NAME = "roberta-large-mnli"

In [None]:
# Load the pretrained tokenizer and sequence-classification model named by MODEL_NAME.
tokenizer = RobertaTokenizer.from_pretrained(MODEL_NAME)

# torch rejects negative device indices, so a negative CUDA_DEVICE (the
# "no GPU" setting) has to be mapped to the CPU explicitly.
model = RobertaForSequenceClassification.from_pretrained(MODEL_NAME).to(
    CUDA_DEVICE if CUDA_DEVICE >= 0 else "cpu"
)

In [None]:
# Directory the MNLI TSV files are read from.
MNLI_DATASET_PATH = "../data/glue_data/MNLI/"

# Class index -> label name, plus the inverse lookup derived from it.
MNLI_IDX2LABEL = dict(enumerate(("contradiction", "neutral", "entailment")))
MNLI_LABEL2IDX = {name: position for position, name in MNLI_IDX2LABEL.items()}

In [None]:
from typing import List, Tuple


def read_mnli_dataset(path: str) -> List[Tuple[str, str, str]]:
    """Read an MNLI TSV file into ``(sent1, sent2, target)`` triples.

    The first line is treated as a header and skipped. For every other
    non-blank line, tab-separated columns 8 and 9 are taken as the two
    sentences and the last column as the target label.

    Args:
        path: Location of the TSV file.

    Returns:
        One ``(sent1, sent2, target)`` tuple of raw strings per data row, in
        file order.

    Raises:
        IndexError: If a non-blank row has fewer than ten columns.
    """
    dataset = []
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default encoding.
    with open(path, encoding="utf-8") as fin:
        next(fin, None)  # skip the header row; tolerates an empty file
        for line in fin:
            row = line.strip()
            if not row:
                # Blank lines (e.g. a trailing newline) carry no example.
                continue
            tokens = row.split('\t')
            sent1, sent2, target = tokens[8], tokens[9], tokens[-1]
            dataset.append((sent1, sent2, target))

    return dataset


def dataset_to_input_instances(dataset: List[Tuple[List[str], List[str], str]]) -> List[InputInstance]:
    """Wrap every dataset entry in an ``InputInstance``.

    Both sentences of an entry are passed through ``web_tokenizer``; the
    entry's position in ``dataset`` becomes the instance id, and the third
    element of the entry is ignored.
    """
    return [
        InputInstance(
            id_=position,
            sent1=web_tokenizer(first),
            sent2=web_tokenizer(second),
        )
        for position, (first, second, _) in enumerate(dataset)
    ]


def get_labels(dataset: List[Tuple[List[str], List[str], str]]) -> List[str]:
    """Return the third element (the label) of every dataset entry, in order."""
    collected = []
    for _first, _second, target in dataset:
        collected.append(target)
    return collected

In [None]:
def predict(input_instance, model, tokenizer, cuda_device):
    """Return the model's class probabilities for one sentence pair.

    Args:
        input_instance: Object exposing ``sent1.tokens`` and ``sent2.tokens``.
        model: Callable whose first output element holds the logits.
        tokenizer: Object whose ``encode`` builds the paired input ids.
        cuda_device: Target for the input tensor. A negative integer selects
            the CPU; any other value is passed to ``Tensor.to`` unchanged.

    Returns:
        Softmax over the last dimension of the logits. The tensor is computed
        under ``torch.no_grad()`` and therefore does not track gradients.
    """
    # torch rejects negative device indices, so map the "-1 means no GPU"
    # convention to the CPU instead of crashing in `.to()`.
    device = "cpu" if isinstance(cuda_device, int) and cuda_device < 0 else cuda_device

    # NOTE(review): the sentences are passed as token lists — confirm the
    # installed tokenizer version treats pre-split input as intended.
    input_ids = tokenizer.encode(text=input_instance.sent1.tokens,
                                 text_pair=input_instance.sent2.tokens,
                                 add_special_tokens=True,
                                 return_tensors="pt").to(device)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(input_ids)[0]
        return F.softmax(logits, dim=-1)

In [None]:
# Load dev_matched.tsv, wrap each example as an InputInstance, and keep the
# labels in a parallel list (an instance's id is its index into `labels`).
dataset = read_mnli_dataset(os.path.join(MNLI_DATASET_PATH, "dev_matched.tsv"))
input_instances = dataset_to_input_instances(dataset)
labels = get_labels(dataset)

In [None]:
def batcher(batch_instances):
    """Score each instance by the probability the model gives its true label.

    For every instance, its ``id`` indexes the notebook-level ``labels`` list
    to find the true label, ``predict`` is run with the notebook-level
    ``model``, ``tokenizer`` and ``CUDA_DEVICE``, and the probability at the
    true label's index is collected.

    Args:
        batch_instances: Iterable of instances exposing an ``id`` attribute
            and accepted by ``predict``.

    Returns:
        A list of floats, one per instance, in input order.
    """
    probabilities = []
    for instance in batch_instances:
        true_label_idx = MNLI_LABEL2IDX[labels[instance.id]]
        # `predict` returns one row per input; only a single pair is encoded.
        probs = predict(instance, model, tokenizer, CUDA_DEVICE)[0]
        probabilities.append(probs[true_label_idx].item())

    return probabilities
    

# Engine configuration using the "unk_replacement" strategy.
config_unk = Config.from_dict({
    "strategy": "unk_replacement",
    "batch_size": 128,
    "unk_token": "___UNK___"
})

# Engine configuration using the "bert_lm_sampling" strategy. The device comes
# from CUDA_DEVICE so it stays in step with the classifier instead of being
# hard-coded separately.
config_resample = Config.from_dict({
    "strategy": "bert_lm_sampling",
    "cuda_device": CUDA_DEVICE,
    "bert_model": "bert-base-uncased",
    "batch_size": 128,
    "n_samples": 100,
    "verbose": False
})

# Both engines share `batcher` as their scoring callback.
unknown_engine = Engine(config_unk, batcher)
resample_engine = Engine(config_resample, batcher)

In [None]:
# Window of instances to explain: `n` consecutive instances starting at `instance_idx`.
instance_idx = 0
n = 5

# Run both engines on the same window of instances.
unk_occluded_instances, unk_instance_probabilities = unknown_engine.run(input_instances[instance_idx: instance_idx+n])
res_occluded_instances, res_instance_probabilities = resample_engine.run(input_instances[instance_idx: instance_idx+n])

In [None]:
# Feed each engine's run outputs back into its own `relevances` method.
unk_relevances_difference = unknown_engine.relevances(unk_occluded_instances, unk_instance_probabilities)
res_relevances_difference = resample_engine.relevances(res_occluded_instances, res_instance_probabilities)

In [None]:
# Dataset labels and argmax predictions for the same window of instances.
labels_true = labels[instance_idx: instance_idx+n]
labels_pred = [
    MNLI_IDX2LABEL[predict(example, model, tokenizer, CUDA_DEVICE)[0].argmax().item()]
    for example in input_instances[instance_idx: instance_idx+n]
]

In [None]:
# Render the relevances from the "unk_replacement" engine alongside true and predicted labels.
HTML(visualize_relevances(input_instances[instance_idx: instance_idx+n], unk_relevances_difference, labels_true, labels_pred))

In [None]:
# Render the relevances from the "bert_lm_sampling" engine alongside true and predicted labels.
HTML(visualize_relevances(input_instances[instance_idx: instance_idx+n], res_relevances_difference, labels_true, labels_pred))

In [None]:
# Compare the relevances produced by the two engines.
# NOTE(review): the correlation measure and return format are defined in xbert.engine — confirm there.
calculate_correlation(unk_relevances_difference, res_relevances_difference)