In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys

sys.path.append("../")

In [None]:
from IPython.core.display import HTML

In [None]:
import numpy as np

from xbert_tasks.predictor_utils import load_predictor

from xbert_tasks.classification.models.text_classifier import TextClassifier
from xbert_tasks.classification.predictors.text_classifier_predictor import TextClassifierPredictor
from xbert_tasks.classification.dataset_readers.sst2_dataset_reader import Sst2DatasetReader

In [None]:
from xbert.engine import Engine, weight_of_evidence, difference_of_log_probabilities, calculate_correlation
from xbert import InputInstance, Config
from xbert.visualization import visualize_relevances

In [None]:
CUDA_DEVICE = 0 # or -1 if no GPU is available

# Python's file APIs do not expand "~" on their own, so resolve the home
# directory here and hand downstream code a real filesystem path.
MODEL_DIR = os.path.expanduser("~/Downloads/xbert_sst2/")
PREDICTOR_NAME = "sst_text_classifier"

In [None]:
# Build the predictor registered under PREDICTOR_NAME from the archive in MODEL_DIR.
predictor = load_predictor(
    MODEL_DIR,
    PREDICTOR_NAME,
    CUDA_DEVICE,
    archive_filename="model.tar.gz",
    weights_file=None,
)

In [None]:
# Expand "~" explicitly: plain string paths are not home-expanded by Python's file APIs.
SST_DATASET_PATH = os.path.expanduser("~/Downloads/SST-2/")

# Materialise the reader output as a list: later cells index and slice
# `dataset_instances`, which a lazy / iterable-only return value would not support.
dataset_instances = list(predictor._dataset_reader.read(os.path.join(SST_DATASET_PATH, "dev.tsv")))

In [None]:
def batcher(batch_instances):
    """Score a batch of instances by the probability assigned to their true label.

    Each instance's ``id`` is used as an index into the notebook-level
    ``dataset_instances`` to look up the true label; its ``text.tokens`` are sent
    to the notebook-level ``predictor`` as the ``text`` input.

    Args:
        batch_instances: iterable of objects exposing ``id`` and ``text.tokens``.

    Returns:
        A list with one entry per instance: the ``class_probabilities`` value at
        the index of that instance's true label.
    """
    label_vocab = predictor._model.vocab.get_token_to_index_vocabulary("labels")

    payloads, gold_indices = [], []
    for candidate in batch_instances:
        gold_label = dataset_instances[candidate.id].fields["label"].label
        gold_indices.append(label_vocab[gold_label])
        payloads.append({"text": candidate.text.tokens})

    # One batched forward pass for the whole list of inputs.
    predictions = predictor.predict_batch_json(payloads)

    gold_probabilities = []
    for prediction, gold_idx in zip(predictions, gold_indices):
        gold_probabilities.append(prediction["class_probabilities"][gold_idx])
    return gold_probabilities
    

# Two engine configurations that differ in the occlusion strategy.
# NOTE(review): the exact semantics of each strategy are inferred from their
# names — confirm against the xbert documentation.
config_unk = Config.from_dict({
    "strategy": "unk_replacement",
    "batch_size": 128,
    "unk_token": "___UNK___"
})

config_resample = Config.from_dict({
    "strategy": "bert_lm_sampling",
    # Reuse the notebook-wide device setting instead of a hard-coded 0, so that
    # changing CUDA_DEVICE (e.g. to -1 when no GPU is available) also changes
    # the device handed to this engine.
    "cuda_device": CUDA_DEVICE,
    "bert_model": "bert-base-uncased",
    "batch_size": 128,
    "n_samples": 100,
    "verbose": False
})

# Both engines are given the same `batcher` callback.
unknown_engine = Engine(config_unk, batcher)
resample_engine = Engine(config_resample, batcher)

In [None]:
# Window of the dataset to explain: `n` instances starting at `instance_idx`.
instance_idx = 100
n = 100
selected_instances = dataset_instances[instance_idx: instance_idx + n]

# Wrap each selected instance's token strings in an InputInstance whose id is the
# instance's position in `dataset_instances` (`batcher` looks labels up by that id).
input_instances = [
    InputInstance(id_=dataset_idx, text=[token.text for token in selected.fields["tokens"].tokens])
    for dataset_idx, selected in enumerate(selected_instances, start=instance_idx)
]

# True labels and the predictor's labels for the same window.
labels_true = [selected.fields["label"].label for selected in selected_instances]
labels_pred = [predictor.predict_instance(selected)["label"] for selected in selected_instances]

In [None]:
# Run both engines over the same inputs; each run yields a pair that is
# unpacked as (occluded instances, instance probabilities).
(unk_occluded_instances,
 unk_instance_probabilities) = unknown_engine.run(input_instances)
(res_occluded_instances,
 res_instance_probabilities) = resample_engine.run(input_instances)

In [None]:
# Turn each engine's run output into relevances, using that same engine.
unk_relevances_difference = unknown_engine.relevances(
    unk_occluded_instances,
    unk_instance_probabilities,
)
res_relevances_difference = resample_engine.relevances(
    res_occluded_instances,
    res_instance_probabilities,
)

In [None]:
# Render the relevances from the "unk_replacement" engine together with the
# true and predicted labels.
HTML(
    visualize_relevances(
        input_instances,
        unk_relevances_difference,
        labels_true,
        labels_pred,
    )
)

In [None]:
# Render the relevances from the "bert_lm_sampling" engine together with the
# true and predicted labels.
HTML(
    visualize_relevances(
        input_instances,
        res_relevances_difference,
        labels_true,
        labels_pred,
    )
)

In [None]:
# Compare the relevances produced by the two engines.
calculate_correlation(
    unk_relevances_difference,
    res_relevances_difference,
)