In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys

sys.path.append("../")

In [None]:
from IPython.core.display import HTML

In [None]:
import numpy as np

from xbert_tasks.predictor_utils import load_predictor

from xbert_tasks.classification.models.text_classifier import TextClassifier
from xbert_tasks.classification.predictors.text_classifier_predictor import TextClassifierPredictor
from xbert_tasks.classification.dataset_readers.sst2_dataset_reader import Sst2DatasetReader

In [None]:
from xbert.occlusion import Engine, weight_of_evidence, difference_of_log_probabilities

In [None]:
def visualize_relevances(inputs, relevances, labels_true=None, labels_pred=None, font_size=5):
    """Render per-token relevance scores as colour-coded HTML.

    Each input becomes one line of output. Within an input the relevances are
    divided by the largest absolute relevance of that input, so the strongest
    token is fully opaque. Positive relevances get a red background, negative
    ones a blue background; the opacity is the normalised magnitude.

    Args:
        inputs: iterable of ``(input_id, tokens)`` pairs, ``tokens`` being a
            sequence of token strings.
        relevances: mapping ``input_id -> {token_index: relevance}``. Every
            token index of an input must be present in its inner mapping.
        labels_true: optional sequence of gold labels aligned with ``inputs``.
            When given, each line is prefixed with the label and the
            un-normalised maximum absolute relevance of that input.
        labels_pred: optional sequence of predicted labels aligned with
            ``inputs``. Only used together with ``labels_true`` to add a
            check mark (match) or a cross (mismatch) to the prefix.
        font_size: value of the ``size`` attribute of the wrapping ``<font>`` tag.

    Returns:
        An ``HTML`` object wrapping all lines, joined with ``<br>``.

    Raises:
        KeyError: if an input id or a token index is missing from ``relevances``.
    """
    # Local import so the cell does not depend on an import made elsewhere.
    from html import escape

    def rgba(relevance):
        # Red for positive, blue for negative; alpha is the magnitude.
        if relevance >= 0:
            return f"rgba(255, 0, 0, {relevance})"
        return f"rgba(0, 0, 255, {abs(relevance)})"

    def color(relevance):
        # Decide on the magnitude: a strongly negative token has an opaque blue
        # background and needs white text just like a strongly positive one.
        return "white" if abs(relevance) > 0.8 else "black"

    visualized_inputs = []
    for i, (input_id, tokens) in enumerate(inputs):
        tokens_relevance = relevances[input_id]
        magnitudes = np.abs(list(tokens_relevance.values()))
        max_relevance = magnitudes.max() if magnitudes.size else 0.0
        # All-zero (or empty) relevances would divide by zero and put "nan"
        # into the CSS; leave such values unscaled instead.
        scale = max_relevance if max_relevance > 0 else 1.0
        norm_tokens_relevance = {idx: r / scale for idx, r in tokens_relevance.items()}

        html_tokens = []
        for idx, token in enumerate(tokens):
            relevance = norm_tokens_relevance[idx]
            # Escape the token text so characters such as "<" or "&" are shown
            # literally instead of being parsed as markup.
            html_token = (
                f'<span style="color:{color(relevance)}; background-color:{rgba(relevance)};">'
                f'{escape(str(token))}</span>'
            )
            html_tokens.append(html_token)

        if labels_true is not None:
            correct = ""
            if labels_pred is not None:
                correct = "&#10004;" if labels_true[i] == labels_pred[i] else "&#10006;"

            prefix = '<span style="color:black; background-color:rgba(255, 255, 0, 0.6);">' \
                + f'{escape(str(labels_true[i]))} {correct} {max_relevance:.2f}</span>:   '
        else:
            prefix = ""

        visualized_input = f'<font size="{font_size}">' + prefix + " ".join(html_tokens) + '</font>'

        visualized_inputs.append(visualized_input)

    # "<br>" is the valid line-break element ("</br>" is a stray end tag).
    return HTML("<br>".join(visualized_inputs))

In [None]:
# Device index handed to load_predictor in the next cell.
CUDA_DEVICE = 0 # or -1 if no GPU is available

# Location passed to load_predictor together with archive_filename="model.tar.gz".
# NOTE(review): the "~" is passed through unexpanded here — confirm that
# load_predictor resolves it to the home directory.
MODEL_DIR = "~/Downloads/xbert_sst2/"
# Predictor name passed to load_predictor.
PREDICTOR_NAME = "text_classifier"

In [None]:
# Build the predictor from the model archive configured above; it is used for
# reading the dataset, for plain predictions and inside the occlusion batcher.
predictor = load_predictor(
    MODEL_DIR,
    PREDICTOR_NAME,
    CUDA_DEVICE,
    archive_filename="model.tar.gz",
    weights_file=None,
)

In [None]:
# Directory holding the SST-2 files; only the dev split is read here.
# NOTE(review): the "~" is passed through unexpanded — confirm the dataset
# reader resolves it to the home directory.
SST_DATASET_PATH = "~/Downloads/SST-2/"

# Reuse the predictor's own dataset reader to load the dev split.
instances = predictor._dataset_reader.read(f"{SST_DATASET_PATH}dev.tsv")

In [None]:
def batcher(batch_candidates):
    """Score a batch of candidates by the probability of their gold label.

    For every candidate, ``candidate.id`` is used as an index into the
    notebook-level ``instances`` to look up the gold label, and
    ``candidate.tokens`` is sent to the predictor as the ``text`` field.

    Args:
        batch_candidates: iterable of objects exposing ``id`` and ``tokens``.

    Returns:
        A list with one entry per candidate: the element of the prediction's
        ``class_probabilities`` at the index of the candidate's gold label.
    """
    label_to_index = predictor._model.vocab.get_token_to_index_vocabulary("labels")

    # One pass over the candidates: (gold label index, predictor payload).
    prepared = [
        (
            label_to_index[instances[candidate.id].fields["label"].label],
            {"text": candidate.tokens},
        )
        for candidate in batch_candidates
    ]
    gold_indices = [gold_idx for gold_idx, _ in prepared]
    payloads = [payload for _, payload in prepared]

    predictions = predictor.predict_batch_json(payloads)

    return [
        prediction["class_probabilities"][gold_idx]
        for prediction, gold_idx in zip(predictions, gold_indices)
    ]
    

# Configuration passed to the occlusion Engine.
params = {
    # Follow the notebook-wide device setting instead of a hard-coded 0, so
    # that switching CUDA_DEVICE to -1 also applies to the engine.
    "cuda_device": CUDA_DEVICE,
    "bert_model": "bert-base-uncased",
    "batch_size": 128,
    "n_samples": 100,
    "verbose": False
}

# The engine scores candidates through `batcher`.
engine = Engine(params, batcher)

In [None]:
# Alternative to the next cell: build inputs and labels for every instance of the dev set instead of a slice.
#inputs = [(idx, [t.text for t in instance.fields["tokens"].tokens]) for idx, instance in enumerate(instances)]
#labels_true = [instance.fields["label"].label for instance in instances]
#labels_pred = [predictor.predict_instance(instance)["label"] for instance in instances]

In [None]:
# Explain a window of `n` instances starting at position `instance_idx`.
instance_idx = 0
n = 100
selected_instances = instances[instance_idx: instance_idx+n]

# Each input is (position in `instances`, token strings); `batcher` uses
# candidate.id as an index into `instances` — presumably the engine carries
# these ids over to its candidates (confirm).
inputs = [
    (idx, [token.text for token in instance.fields["tokens"].tokens])
    for idx, instance in enumerate(selected_instances, start=instance_idx)
]
labels_true = [instance.fields["label"].label for instance in selected_instances]
labels_pred = [predictor.predict_instance(instance)["label"] for instance in selected_instances]

In [None]:
# Run the occlusion engine on the selected inputs; the scores are read back
# afterwards through engine.relevances().
engine.run(inputs)

In [None]:
# Relevances from the engine run above under three scoring methods.
relevances_evidence = engine.relevances(scoring_method=weight_of_evidence)
# No scoring_method given, so the engine's default is used — presumably a plain
# probability difference, going by the variable name (confirm in xbert.occlusion).
relevances_difference = engine.relevances()
relevances_difference_log = engine.relevances(scoring_method=difference_of_log_probabilities)

In [None]:
# Relevances scored with weight_of_evidence.
visualize_relevances(inputs, relevances_evidence, labels_true, labels_pred)

In [None]:
# Relevances scored with the engine's default scoring method.
visualize_relevances(inputs, relevances_difference, labels_true, labels_pred)

In [None]:
# Relevances scored with difference_of_log_probabilities.
visualize_relevances(inputs, relevances_difference_log, labels_true, labels_pred)